diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 547b5ea48a555..41c3c3a89fa72 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2090,8 +2090,8 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression (which must refer only to this DataFrame), or an atomic vector in -#' the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this SparkDataFrame), or an atomic +#' vector in the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9cac7d00cc6b6..0bc571874f07c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -171,7 +171,9 @@ private class DownloadCallback implements StreamCallback { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { - channel.write(buf); + while (buf.hasRemaining()) { + channel.write(buf); + } } @Override diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index e9539dc73f6fa..55a9122cf9026 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -244,7 +244,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until testOutputCopies) { // Shift values by i so that they're different in the output val alteredOutput = testOutput.map(b => (b + i).toByte) - channel.write(ByteBuffer.wrap(alteredOutput)) + val buffer = ByteBuffer.wrap(alteredOutput) + while (buffer.hasRemaining) { + channel.write(buffer) + } } channel.close() file.close() diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b957445579ffd..702bcf748fc74 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -108,7 +108,13 @@ and the migration guide below will explain all changes between releases. ### Breaking changes -There are no breaking changes. +* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner +and better accommodate the addition of the multi-class summary. This is a breaking change for user +code that casts a `LogisticRegressionTrainingSummary` to a +` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail +(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which +will still work correctly for both multinomial and binary cases. ### Deprecations and changes of behavior @@ -123,8 +129,19 @@ new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial). In 2.2 and + The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and earlier versions, the level of parallelism was set to the default threadpool size in Scala. +* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156): + The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than + `1`. This will cause training results to be different between `2.3` and earlier versions. +* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681): + Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients + when some features had zero variance. +* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957): + Tree algorithms now use mid-points for split values. This may change results from model training. +* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657): + Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent + with the output in R. This may change results from model training in this scenario. ## Previous Spark versions diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 8c733426b256f..41c443bc12120 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRo import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..8d41c0da2b133 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 85e96b6783327..694ca76e24964 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index a24efdefa4464..9307bfc001c03 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1155ea5fdd85b..22e7b8bbf1ff5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -74,7 +74,7 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with * @group param */ @Since("2.3.0") - final override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data (unseen or NULL values) in features and label column of string " + "type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index a5873d03b4161..6d3fe7a6c748c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -645,7 +645,7 @@ class LinearRegressionModel private[ml] ( extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams with MLWritable { - def this(uid: String, coefficients: Vector, intercept: Double) = + private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) private var trainingSummary: Option[LinearRegressionTrainingSummary] = None diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 055b2c4a0ffec..1496cba91b90e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1829,7 +1829,7 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. - The column expression must be an expression over this dataframe; attempting to add + The column expression must be an expression over this DataFrame; attempting to add a column from some other dataframe will raise an error. :param colName: string, name of the new column. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6c84023c43fb6..1ed04298bc899 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -213,7 +213,12 @@ def __init__(self, sparkContext, jsparkSession=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm if jsparkSession is None: - jsparkSession = self._jvm.SparkSession(self._jsc.sc()) + if self._jvm.SparkSession.getDefaultSession().isDefined() \ + and not self._jvm.SparkSession.getDefaultSession().get() \ + .sparkContext().isStopped(): + jsparkSession = self._jvm.SparkSession.getDefaultSession().get() + else: + jsparkSession = self._jvm.SparkSession(self._jsc.sc()) self._jsparkSession = jsparkSession self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) @@ -225,6 +230,7 @@ def __init__(self, sparkContext, jsparkSession=None): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + self._jvm.SparkSession.setDefaultSession(self._jsparkSession) def _repr_html_(self): return """ @@ -759,6 +765,8 @@ def stop(self): """Stop the underlying :class:`SparkContext`. """ self._sc.stop() + # We should clean the default session up. See SPARK-23228. + self._jvm.SparkSession.clearDefaultSession() SparkSession._instantiatedSession = None @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dc80870d3cd9f..b27363023ae77 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -69,7 +69,7 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings from pyspark.sql.types import _merge_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -2925,6 +2925,32 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class SparkSessionTests(PySparkTestCase): + + # This test is separate because it's closely related with session's start and stop. + # See SPARK-23228. + def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same with the exiting one. + self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: @@ -4504,19 +4530,19 @@ def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return v.mean(), v.std() with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4d5e3bb043671..2f88feb0f1fdf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -718,7 +719,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + cause) + "User class threw exception: " + StringUtils.stringifyException(cause)) } sparkContextPromise.tryFailure(e.getCause()) } finally { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 91cb0365a0856..7848f88bda1c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1493,7 +1493,7 @@ class Analyzer( // to push down this ordering expression and can reference the original aggregate // expression instead. val needsPushDown = ArrayBuffer.empty[NamedExpression] - val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + val evaluatedOrderings = resolvedAliasedOrdering.zip(unresolvedSortOrders).map { case (evaluated, order) => val index = originalAggExprs.indexWhere { case Alias(child, _) => child semanticEquals evaluated.child @@ -2038,7 +2038,11 @@ class Analyzer( WindowExpression(wf, s.copy(frameSpecification = wf.frame)) case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } we.copy(windowSpec = s.copy(frameSpecification = frame)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index dd13d9a3bba51..78895f1c2f6f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -265,27 +265,6 @@ case class SpecifiedWindowFrame( } } -object SpecifiedWindowFrame { - /** - * @param hasOrderSpecification If the window spec has order by expressions. - * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return the default window frame. - */ - def defaultWindowFrame( - hasOrderSpecification: Boolean, - acceptWindowFrame: Boolean): SpecifiedWindowFrame = { - if (hasOrderSpecification && acceptWindowFrame) { - // If order spec is defined and the window function supports user specified window frames, - // the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - // Otherwise, the default frame is - // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - } -} - case class UnresolvedWindowExpression( child: Expression, windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 89bfcee078fba..45edf266bbce4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,18 +46,27 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case Except(left, right) if isEligible(left, right) => - Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + case e @ Except(left, right) if isEligible(left, right) => + val newCondition = transformCondition(left, skipProject(right)) + newCondition.map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { val filterCondition = InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { + Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + } else { + None + } } // TODO: This can be further extended in the future. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7394a0d7cf983..90654e67457e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -375,6 +375,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.parquet.columnarReaderBatchSize") + .doc("The number of rows to include in a parquet vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -398,6 +404,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.orc.columnarReaderBatchSize") + .doc("The number of rows to include in a orc vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + "vectorized ORC reader.") @@ -1250,10 +1262,14 @@ class SQLConf extends Serializable with Logging { def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index e9701ffd2c54b..52dc2e9fb076c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -168,6 +168,21 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter when only right filter can be applied to the left") { + val table = LocalRelation(Seq('a.int, 'b.int)) + val left = table.where('b < 1).select('a).as("left") + val right = table.where('b < 3).select('a).as("right") + + val query = Except(left, right) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(left.output, right.output, + Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Distinct with Aggregate") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 5078bc7922ee2..c8add4c9f486c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.TimestampType; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; /** @@ -77,6 +78,11 @@ public void close() { } + @Override + public boolean hasNull() { + return !baseData.noNulls; + } + @Override public int numNulls() { if (baseData.isRepeating) { @@ -172,6 +178,11 @@ public ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5e7cad470e1d1..dcebdc39f0aa2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,8 +49,9 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -81,9 +82,10 @@ public class OrcColumnarBatchReader extends RecordReader { // Whether or not to copy the ORC columnar batch to Spark columnar batch. private final boolean copyToSpark; - public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark, int capacity) { MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; this.copyToSpark = copyToSpark; + this.capacity = capacity; } @@ -148,7 +150,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(CAPACITY); + batch = orcSchema.createRowBatch(capacity); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -162,15 +164,15 @@ public void initBatch( if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -193,8 +195,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); - missingCol.putNulls(0, CAPACITY); + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -206,7 +208,7 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index bb1b23611a7d7..5934a23db8af1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,8 +50,9 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; /** * Batch of rows that we assemble and the current index we've returned. Every time this @@ -115,13 +116,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private final MemoryMode MEMORY_MODE; - public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap) { + public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap, int capacity) { this.convertTz = convertTz; MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; - } - - public VectorizedParquetRecordReader(boolean useOffHeap) { - this(null, useOffHeap); + this.capacity = capacity; } /** @@ -199,9 +197,9 @@ private void initBatch( } if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); } columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { @@ -215,7 +213,7 @@ private void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -257,7 +255,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index a2853bbadc92b..829f3ce750fe6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -20,8 +20,10 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.sql.Date; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; @@ -30,6 +32,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -109,6 +112,18 @@ public static int[] toJavaIntArray(ColumnarArray array) { return array.toIntArray(); } + public static Map toJavaIntMap(ColumnarMap map) { + int[] keys = toJavaIntArray(map.keyArray()); + int[] values = toJavaIntArray(map.valueArray()); + assert keys.length == values.length; + + Map result = new HashMap<>(); + for (int i = 0; i < keys.length; i++) { + result.put(keys[i], values[i]); + } + return result; + } + private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { if (t instanceof CalendarIntervalType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 2bab095d4d951..307c19032dee5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -21,10 +21,10 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; @@ -146,9 +146,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChild(0).getInt(rowId); - final long microseconds = columns[ordinal].getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return columns[ordinal].getInterval(rowId); } @Override @@ -164,8 +162,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getMap(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 1c45b846790b6..754c26579ff08 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -60,7 +60,7 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long nulls; private long data; - // Set iff the type is array. + // Only set if type is Array or Map. private long lengthData; private long offsetData; @@ -123,7 +123,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; long offset = nulls + rowId; for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 0); @@ -530,7 +530,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { @Override protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; - if (isArray()) { + if (isArray() || type instanceof MapType) { this.lengthData = Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 1d538fe4181b7..23dcc104e67c4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -69,7 +69,7 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private float[] floatData; private double[] doubleData; - // Only set if type is Array. + // Only set if type is Array or Map. private int[] arrayLengths; private int[] arrayOffsets; @@ -119,7 +119,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; } @@ -503,7 +503,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { - if (isArray()) { + if (isArray() || type instanceof MapType) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index a8ec8ef2aadf8..9d447cdc79063 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -59,8 +60,8 @@ public void reset() { elementsAppended = 0; if (numNulls > 0) { putNotNulls(0, capacity); + numNulls = 0; } - numNulls = 0; } @Override @@ -97,11 +98,19 @@ public void reserve(int requiredCapacity) { private void throwUnsupportedException(int requiredCapacity, Throwable cause) { String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + - " to false."; + "vectorized reader, or increase the vectorized reader batch size. For parquet file " + + "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; throw new RuntimeException(message, cause); } + @Override + public boolean hasNull() { + return numNulls > 0; + } + @Override public int numNulls() { return numNulls; } @@ -607,6 +616,13 @@ public final ColumnarArray getArray(int rowId) { return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } + // `WritableColumnVector` puts the key array in the first child column vector, value array in the + // second child column vector, and puts the offsets and lengths in the current column vector. + @Override + public final ColumnarMap getMap(int rowId) { + return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); + } + public WritableColumnVector arrayData() { return childColumns[0]; } @@ -700,6 +716,11 @@ protected WritableColumnVector(int capacity, DataType type) { for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } + } else if (type instanceof MapType) { + MapType mapType = (MapType) type; + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType()); + this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType()); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. this.childColumns = new WritableColumnVector[2]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java index f79424e036a52..0c1d5d1a9577a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java index 22660e42ad850..5e8f0c0dafdcf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index a2383a9d7d680..5405a916951b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 27905e325df87..2d0ee50212b56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index b37562167d9ef..f6b111fdf220d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * An interface to represent data distribution requirement, which specifies how the records should diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 5e334d13a1215..309d9e5de0a0f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -15,9 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java index 3f13a4dbf5793..47d26440841fd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 6e5177ee83a62..d1d1e7ffd1dd4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index fcec446d892f5..67ebde30d61a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index abba3e7188b13..e41c0351edc82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index 4688b85f49f5f..383e73db6762b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import java.io.Serializable; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index d89d27d0e5b1b..52324b3792b8a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -28,7 +28,7 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( + * {@link StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. @@ -62,6 +62,14 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Handles a commit message on receiving from a successful data writer. + * + * If this method fails (by throwing an exception), this writing job is considered to to have been + * failed, and {@link #abort(WriterCommitMessage[])} would be called. + */ + default void onDataWriterCommit(WriterCommitMessage message) {} + /** * Commits this writing job with a list of commit messages. The commit messages are collected from * successful data writers and are produced by {@link DataWriter#commit()}. @@ -78,8 +86,10 @@ public interface DataSourceWriter { void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * Aborts this writing job because some data writers are failed and keep failing when retry, + * or the Spark job fails with some unknown reasons, + * or {@link #onDataWriterCommit(WriterCommitMessage)} fails, + * or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java index 7c5f304425093..1c0e2e12f8d51 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 915ee6c4fb390..4913341bd505d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.writer; +package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 9803c3dec6de2..f3ece538c3b80 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. Currently time interval type and map type are not + * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. */ @InterfaceStability.Evolving @@ -37,6 +37,11 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; + @Override + public boolean hasNull() { + return accessor.getNullCount() > 0; + } + @Override public int numNulls() { return accessor.getNullCount(); @@ -114,6 +119,11 @@ public ColumnarArray getArray(int rowId) { return accessor.getArray(rowId); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 4b955ceddd0f2..530d4d23d4eaf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -17,9 +17,9 @@ package org.apache.spark.sql.vectorized; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -64,6 +64,11 @@ public abstract class ColumnVector implements AutoCloseable { @Override public abstract void close(); + /** + * Returns true if this column vector contains any null values. + */ + public abstract boolean hasNull(); + /** * Returns the number of nulls in this column vector. */ @@ -195,6 +200,7 @@ public double[] getDoubles(int rowId, int count) { * struct field. */ public final ColumnarRow getStruct(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarRow(this, rowId); } @@ -213,10 +219,18 @@ public final ColumnarRow getStruct(int rowId) { /** * Returns the map type value for rowId. + * + * In Spark, map type value is basically a key data array and a value data array. A key from the + * key array with a index and a value from the value array with the same index contribute to + * an entry of this map type value. + * + * To support map type, implementations must construct an {@link ColumnarMap} and return it in + * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all + * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data + * of all the values of all the maps in this vector, and a pair of offset and length which + * specify the range of the key/value array that belongs to the map type value at rowId. */ - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } + public abstract ColumnarMap getMap(int ordinal); /** * Returns the decimal type value for rowId. @@ -236,9 +250,29 @@ public MapData getMap(int ordinal) { public abstract byte[] getBinary(int rowId); /** - * Returns the ordinal's child column vector. + * Returns the calendar interval type value for rowId. + * + * In Spark, calendar interval type value is basically an integer value representing the number of + * months in this interval, and a long value representing the number of microseconds in this + * interval. An interval type vector is the same as a struct type vector with 2 fields: `months` + * and `microseconds`. + * + * To support interval type, implementations must implement {@link #getChild(int)} and define 2 + * child vectors: the first child vector is an int type vector, containing all the month values of + * all the interval values in this vector. The second child vector is a long type vector, + * containing all the microsecond values of all the interval values in this vector. + */ + public final CalendarInterval getInterval(int rowId) { + if (isNullAt(rowId)) return null; + final int months = getChild(0).getInt(rowId); + final long microseconds = getChild(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + /** + * @return child [[ColumnVector]] at the given ordinal. */ - public abstract ColumnVector getChild(int ordinal); + protected abstract ColumnVector getChild(int ordinal); /** * Data type for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d2c3ec8648d3..72a192d089b9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -18,7 +18,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -135,9 +134,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChild(0).getInt(offset + ordinal); - long microseconds = data.getChild(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); + return data.getInterval(offset + ordinal); } @Override @@ -151,8 +148,8 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java new file mode 100644 index 0000000000000..35648e386c4f1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.vectorized; + +import org.apache.spark.sql.catalyst.util.MapData; + +/** + * Map abstraction in {@link ColumnVector}. + */ +public final class ColumnarMap extends MapData { + private final ColumnarArray keys; + private final ColumnarArray values; + private final int length; + + public ColumnarMap(ColumnVector keys, ColumnVector values, int offset, int length) { + this.length = length; + this.keys = new ColumnarArray(keys, offset, length); + this.values = new ColumnarArray(values, offset, length); + } + + @Override + public int numElements() { return length; } + + @Override + public ColumnarArray keyArray() { + return keys; + } + + @Override + public ColumnarArray valueArray() { + return values; + } + + @Override + public ColumnarMap copy() { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 25db7e09d20d0..b400f7f93c1fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -19,7 +19,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -139,9 +138,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (data.getChild(ordinal).isNullAt(rowId)) return null; - final int months = data.getChild(ordinal).getChild(0).getInt(rowId); - final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return data.getChild(ordinal).getInterval(rowId); } @Override @@ -157,8 +154,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getMap(rowId); } @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index aa66ee7e948ea..ba1157d5b6a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -90,16 +90,15 @@ case class RowDataSourceScanExec( Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => + val numOutputRows = longMetric("numOutputRows") + + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) - iter.map(proj) - } - - val numOutputRows = longMetric("numOutputRows") - unsafeRow.map { r => - numOutputRows += 1 - r + iter.map( r => { + numOutputRows += 1 + proj(r) + }) } } @@ -326,22 +325,22 @@ case class FileSourceScanExec( // 2) the number of columns should be smaller than spark.sql.codegen.maxFields WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { - val unsafeRows = { - val scan = inputRDD - if (needsUnsafeRowConversion) { - scan.mapPartitionsWithIndexInternal { (index, iter) => - val proj = UnsafeProjection.create(schema) - proj.initialize(index) - iter.map(proj) - } - } else { - scan - } - } val numOutputRows = longMetric("numOutputRows") - unsafeRows.map { r => - numOutputRows += 1 - r + + if (needsUnsafeRowConversion) { + inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + val proj = UnsafeProjection.create(schema) + proj.initialize(index) + iter.map( r => { + numOutputRows += 1 + proj(r) + }) + } + } else { + inputRDD.map { r => + numOutputRows += 1 + r + } } } } @@ -445,16 +444,29 @@ case class FileSourceScanExec( currentSize = 0 } - // Assign files to partitions using "Next Fit Decreasing" - splitFiles.foreach { file => - if (currentSize + file.length > maxSplitBytes) { - closePartition() + def addFile(file: PartitionedFile): Unit = { + currentFiles += file + currentSize += file.length + openCostInBytes + } + + var frontIndex = 0 + var backIndex = splitFiles.length - 1 + + while (frontIndex <= backIndex) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + while (frontIndex <= backIndex && + currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + } + while (backIndex > frontIndex && + currentSize + splitFiles(backIndex).length <= maxSplitBytes) { + addFile(splitFiles(backIndex)) + backIndex -= 1 } - // Add the given file to the current partition. - currentSize += file.length + openCostInBytes - currentFiles += file + closePartition() } - closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 2dd314d165348..dbf3bc6f0ee6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -151,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val capacity = sqlConf.orcVectorizedReaderBatchSize val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = @@ -186,7 +187,7 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( - enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index f53a97ba45a26..ba69f9a26c968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -350,6 +350,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetRecordFilterEnabled val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -396,7 +397,7 @@ class ParquetFileFormat val taskContext = Option(TaskContext.get()) val parquetReader = if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( - convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined) + convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 943d0100aca56..017a6737161a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Partitioning} /** * An adapter from public data source partitioning to catalyst internal `Partitioning`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index c544adbf32cdf..eefbcf4c0e087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -80,7 +80,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e rdd, runTask, rdd.partitions.indices, - (index, message: WriterCommitMessage) => messages(index) = message + (index, message: WriterCommitMessage) => { + messages(index) = message + writer.onDataWriterCommit(message) + } ) if (!writer.isInstanceOf[StreamWriter]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 77034ba0cfc19..045d2b4b9569c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow +import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 5e3fee633f591..ce5e63f5bde85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -30,11 +30,10 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 261d69bbd9843..02fed50485b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.streaming.reader.Offset { + extends v2.reader.streaming.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a0ee683a895d8..845c8d2c14e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 3f5bb489d6528..db600866067bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 8a7a38b22caca..cf02c0dda25d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -18,23 +18,19 @@ package org.apache.spark.sql.execution.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{SystemClock, ThreadUtils} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9402d7c1dcefd..08c81419a9d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index ff028ebc4236a..0eaaa4889ba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamO import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 84d262116cb46..cc6808065c0cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index d46f4d7b86360..d276403190b3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.streaming.sources +import scala.collection.JavaConverters._ + import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ @@ -61,7 +63,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) println("-------------------------------------------") // scalastyle:off println spark - .createDataFrame(spark.sparkContext.parallelize(rows), schema) + .createDataFrame(rows.toList.asJava, schema) .show(numRowsToShow, isTruncated) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d7ce9a7b84479..56f7ff25cbed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 6de113f21ce28..077a255946a6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 58767261dc684..3411edbc53412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f1b3f93c4e1fc..116ac3da07b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 3b5b30d77945c..9aac360fd4bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index fdd709cdb1f38..ddb1edc433d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 99cca0f6dd626..32fad59b97ff6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -27,6 +27,9 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; import org.apache.spark.sql.types.StructType; public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 33707080c1301..8b66f77b2f923 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -589,6 +589,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + test("except distinct - SQL compliance") { val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") val df_right = Seq(1, 3).toDF("id") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ffd736d2ebbb6..8f14575c3325f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.File -import java.math.MathContext import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -1618,6 +1617,46 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23281: verify the correctness of sort direction on composite order by clause") { + withTempView("src") { + Seq[(Integer, Integer)]( + (1, 1), + (1, 3), + (2, 3), + (3, 3), + (4, null), + (5, null) + ).toDF("key", "value").createOrReplaceTempView("src") + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key + """.stripMargin), + Seq(Row(3, 1), Row(3, 2), Row(3, 3), Row(null, 4), Row(null, 5))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key desc + """.stripMargin), + Seq(Row(3, 3), Row(3, 2), Row(3, 1), Row(null, 5), Row(null, 4))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value asc, key desc + """.stripMargin), + Seq(Row(null, 5), Row(null, 4), Row(3, 3), Row(3, 2), Row(3, 1))) + } + } + test("run sql directly on files") { val df = spark.range(100).toDF() withTempPath(f => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..bfccc9335b361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,16 +141,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] - assert(partitions.size == 4, "when checking partitions") - assert(partitions(0).files.size == 1, "when checking partition 1") + // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] + assert(partitions.size == 3, "when checking partitions") + assert(partitions(0).files.size == 2, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") - assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1) + // First partition reads (file1, file6) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) + assert(partitions(0).files(1).start == 0) + assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -163,10 +164,6 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) - - // Final partition reads (file6) - assert(partitions(3).files(0).start == 0) - assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index edb1290ee2eb0..db73bfa149aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -40,7 +40,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -65,7 +67,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -94,7 +98,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file, null /* set columns to null to project all columns */) val column = reader.resultBatch().column(0) assert(reader.nextBatch()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f3ece5b15e26a..3af80930ec807 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -653,7 +653,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -670,7 +672,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -686,7 +690,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -703,7 +709,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, List[String]().asJava) var result = 0 @@ -742,8 +750,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) - val vectorizedReader = - new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val vectorizedReader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 86a3c71a3c4f6..e43336d947364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -76,6 +76,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) @@ -96,7 +97,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -119,7 +121,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -262,6 +265,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") @@ -279,7 +283,8 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index b060aeeef811d..3158995ec62f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index e794f50781ff2..b55489cb2678a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -42,6 +42,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -69,6 +70,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -96,6 +98,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -123,6 +126,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -150,6 +154,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -177,6 +182,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -204,6 +210,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -232,6 +239,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -258,6 +266,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -300,6 +309,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val array0 = columnVector.getArray(0) @@ -344,6 +354,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(!columnVector.hasNull) assert(columnVector.numNulls === 0) val row0 = columnVector.getStruct(0) @@ -396,6 +407,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val row0 = columnVector.getStruct(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 1873c24ab063c..8fe2985836f2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -66,22 +66,27 @@ class ColumnarBatchSuite extends SparkFunSuite { column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNull() reference += false + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNull() reference += true + assert(column.hasNull) assert(column.numNulls() == 1) column.appendNulls(3) (1 to 3).foreach(_ => reference += true) + assert(column.hasNull) assert(column.numNulls() == 4) idx = column.elementsAppended @@ -89,11 +94,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putNotNull(idx) reference += false idx += 1 + assert(column.hasNull) assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 + assert(column.hasNull) assert(column.numNulls() == 5) column.putNulls(idx, 3) @@ -101,6 +108,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += true reference += true idx += 3 + assert(column.hasNull) assert(column.numNulls() == 8) column.putNotNulls(idx, 4) @@ -109,6 +117,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 + assert(column.hasNull) assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => @@ -620,6 +629,39 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 0) } + testVector("CalendarInterval APIs", 4, CalendarIntervalType) { + column => + val reference = mutable.ArrayBuffer.empty[CalendarInterval] + + val months = column.getChild(0) + val microseconds = column.getChild(1) + assert(months.dataType() == IntegerType) + assert(microseconds.dataType() == LongType) + + months.putInt(0, 1) + microseconds.putLong(0, 100) + reference += new CalendarInterval(1, 100) + + months.putInt(1, 0) + microseconds.putLong(1, 2000) + reference += new CalendarInterval(0, 2000) + + column.putNull(2) + reference += null + + months.putInt(3, 20) + microseconds.putLong(3, 0) + reference += new CalendarInterval(20, 0) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getInterval(i), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { column => @@ -631,35 +673,37 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } - // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + // Populate it with arrays [0], [1, 2], null, [], [3, 4, 5] column.putArray(0, 0, 1) column.putArray(1, 1, 2) - column.putArray(2, 2, 0) - column.putArray(3, 3, 3) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getArray(0).numElements == 1) + assert(column.getArray(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getArray(3).numElements == 0) + assert(column.getArray(4).numElements == 3) val a1 = ColumnVectorUtils.toJavaIntArray(column.getArray(0)) val a2 = ColumnVectorUtils.toJavaIntArray(column.getArray(1)) - val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(2)) - val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(4)) assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) - // Verify the ArrayData APIs - assert(column.getArray(0).numElements() == 1) + // Verify the ArrayData get APIs assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).numElements() == 0) - - assert(column.getArray(3).numElements() == 3) - assert(column.getArray(3).getInt(0) == 3) - assert(column.getArray(3).getInt(1) == 4) - assert(column.getArray(3).getInt(2) == 5) + assert(column.getArray(4).getInt(0) == 3) + assert(column.getArray(4).getInt(1) == 4) + assert(column.getArray(4).getInt(2) == 5) // Add a longer array which requires resizing column.reset() @@ -669,8 +713,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) - === array) + assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) === array) } test("toArray for primitive types") { @@ -728,6 +771,43 @@ class ColumnarBatchSuite extends SparkFunSuite { } } + test("Int Map") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = allocate(10, new MapType(IntegerType, IntegerType, false), memMode) + (0 to 1).foreach { colIndex => + val data = column.getChild(colIndex) + (0 to 5).foreach {i => + data.putInt(i, i * (colIndex + 1)) + } + } + + // Populate it with maps [0->0], [1->2, 2->4], null, [], [3->6, 4->8, 5->10] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getMap(0).numElements == 1) + assert(column.getMap(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getMap(3).numElements == 0) + assert(column.getMap(4).numElements == 3) + + val a1 = ColumnVectorUtils.toJavaIntMap(column.getMap(0)) + val a2 = ColumnVectorUtils.toJavaIntMap(column.getMap(1)) + val a4 = ColumnVectorUtils.toJavaIntMap(column.getMap(3)) + val a5 = ColumnVectorUtils.toJavaIntMap(column.getMap(4)) + + assert(a1.asScala == Map(0 -> 0)) + assert(a2.asScala == Map(1 -> 2, 2 -> 4)) + assert(a4.asScala == Map()) + assert(a5.asScala == Map(3 -> 6, 4 -> 8, 5 -> 10)) + + column.close() + } + } + testVector( "Struct Column", 10, @@ -739,14 +819,20 @@ class ColumnarBatchSuite extends SparkFunSuite { c1.putInt(0, 123) c2.putDouble(0, 3.45) - c1.putInt(1, 456) - c2.putDouble(1, 5.67) + + column.putNull(1) + + c1.putInt(2, 456) + c2.putDouble(2, 5.67) val s = column.getStruct(0) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) - val s2 = column.getStruct(1) + assert(column.isNullAt(1)) + assert(column.getStruct(1) == null) + + val s2 = column.getStruct(2) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ee50e8a92270b..23147fffe8a08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,13 +21,14 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -197,6 +198,31 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple counter in writer with onDataWriterCommit") { + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + val numPartition = 6 + spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + assert(SimpleCounter.getCounter == numPartition, + "method onDataWriterCommit should be called as many as the number of partitions") + } + } + } + + test("SPARK-23293: data source v2 self join") { + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + val df2 = df.select(($"i" + 1).as("k"), $"j") + checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a131b16953e3b..36dd2a350a055 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -66,9 +66,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { + SimpleCounter.resetCounter new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } + override def onDataWriterCommit(message: WriterCommitMessage): Unit = { + SimpleCounter.increaseCounter + } + override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) @@ -183,6 +188,22 @@ class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) } } +private[v2] object SimpleCounter { + private var count: Int = 0 + + def increaseCounter: Unit = { + count += 1 + } + + def getCounter: Int = { + count + } + + def resetCounter: Unit = { + count = 0 + } +} + class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3127d664d32dc..cb873ab688e96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.streaming._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, DataReaderFactory, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 108074cce3d6d..fc818bc69c761 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -44,7 +44,7 @@ import org.apache.hadoop.hive.ql.history.HiveHistory; import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.processors.SetProcessor; +import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hive.common.util.HiveVersionInfo; @@ -71,6 +71,12 @@ import org.apache.hive.service.cli.thrift.TProtocolVersion; import org.apache.hive.service.server.ThreadWithGarbageCleanup; +import static org.apache.hadoop.hive.conf.SystemVariables.ENV_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVECONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVEVAR_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.METACONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.SYSTEM_PREFIX; + /** * HiveSession * @@ -209,7 +215,7 @@ private void configureSession(Map sessionConfMap) throws HiveSQL String key = entry.getKey(); if (key.startsWith("set:")) { try { - SetProcessor.setVariable(key.substring(4), entry.getValue()); + setVariable(key.substring(4), entry.getValue()); } catch (Exception e) { throw new HiveSQLException(e); } @@ -221,6 +227,70 @@ private void configureSession(Map sessionConfMap) throws HiveSQL } } + // Copy from org.apache.hadoop.hive.ql.processors.SetProcessor, only change: + // setConf(varname, propName, varvalue, true) when varname.startsWith(HIVECONF_PREFIX) + public static int setVariable(String varname, String varvalue) throws Exception { + SessionState ss = SessionState.get(); + if (varvalue.contains("\n")){ + ss.err.println("Warning: Value had a \\n character in it."); + } + varname = varname.trim(); + if (varname.startsWith(ENV_PREFIX)){ + ss.err.println("env:* variables can not be set."); + return 1; + } else if (varname.startsWith(SYSTEM_PREFIX)){ + String propName = varname.substring(SYSTEM_PREFIX.length()); + System.getProperties().setProperty(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(HIVECONF_PREFIX)){ + String propName = varname.substring(HIVECONF_PREFIX.length()); + setConf(varname, propName, varvalue, true); + } else if (varname.startsWith(HIVEVAR_PREFIX)) { + String propName = varname.substring(HIVEVAR_PREFIX.length()); + ss.getHiveVariables().put(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(METACONF_PREFIX)) { + String propName = varname.substring(METACONF_PREFIX.length()); + Hive hive = Hive.get(ss.getConf()); + hive.setMetaConf(propName, new VariableSubstitution().substitute(ss.getConf(), varvalue)); + } else { + setConf(varname, varname, varvalue, true); + } + return 0; + } + + // returns non-null string for validation fail + private static void setConf(String varname, String key, String varvalue, boolean register) + throws IllegalArgumentException { + HiveConf conf = SessionState.get().getConf(); + String value = new VariableSubstitution().substitute(conf, varvalue); + if (conf.getBoolVar(HiveConf.ConfVars.HIVECONFVALIDATION)) { + HiveConf.ConfVars confVars = HiveConf.getConfVars(key); + if (confVars != null) { + if (!confVars.isType(value)) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED because ").append(key).append(" expects "); + message.append(confVars.typeString()).append(" type value."); + throw new IllegalArgumentException(message.toString()); + } + String fail = confVars.validate(value); + if (fail != null) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED in validation : ").append(fail).append('.'); + throw new IllegalArgumentException(message.toString()); + } + } else if (key.startsWith("hive.")) { + throw new IllegalArgumentException("hive configuration " + key + " does not exists."); + } + } + conf.verifyAndSet(key, value); + if (register) { + SessionState.get().getOverriddenConfigurations().put(key, value); + } + } + @Override public void setOperationLogSessionDir(File operationLogRootDir) { if (!operationLogRootDir.exists()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 664bc20601eaa..3cfc81b8a9579 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -102,7 +102,7 @@ private[hive] class SparkExecuteStatementOperation( to += from.getAs[Timestamp](ordinal) case BinaryType => to += from.getAs[Array[Byte]](ordinal) - case _: ArrayType | _: StructType | _: MapType => + case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => val hiveString = HiveUtils.toHiveString((from.get(ordinal), dataTypes(ordinal))) to += hiveString } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a0e5012633f5e..bf7c01f60fb5c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} +import org.apache.spark.sql.internal.SQLConf /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -50,6 +51,9 @@ private[thriftserver] class SparkSQLOperationManager() require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") val conf = sqlContext.sessionState.conf + val hiveSessionState = parentSession.getSessionState + setConfMap(conf, hiveSessionState.getOverriddenConfigurations) + setConfMap(conf, hiveSessionState.getHiveVariables) val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) @@ -58,4 +62,12 @@ private[thriftserver] class SparkSQLOperationManager() s"runInBackground=$runInBackground") operation } + + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { + val iterator = confMap.entrySet().iterator() + while (iterator.hasNext) { + val kv = iterator.next() + conf.setConfString(kv.getKey, kv.getValue) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 7289da71a3365..496f8c82a6c61 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -135,6 +135,22 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("Support beeline --hiveconf and --hivevar") { + withJdbcStatement() { statement => + executeTest(hiveConfList) + executeTest(hiveVarList) + def executeTest(hiveList: String): Unit = { + hiveList.split(";").foreach{ m => + val kv = m.split("=") + // select "${a}"; ---> avalue + val resultSet = statement.executeQuery("select \"${" + kv(0) + "}\"") + resultSet.next() + assert(resultSet.getString(1) === kv(1)) + } + } + } + } + test("JDBC query execution") { withJdbcStatement("test") { statement => val queries = Seq( @@ -740,10 +756,11 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { s"""jdbc:hive2://localhost:$serverPort/ |default? |hive.server2.transport.mode=http; - |hive.server2.thrift.http.path=cliservice + |hive.server2.thrift.http.path=cliservice; + |${hiveConfList}#${hiveVarList} """.stripMargin.split("\n").mkString.trim } else { - s"jdbc:hive2://localhost:$serverPort/" + s"jdbc:hive2://localhost:$serverPort/?${hiveConfList}#${hiveVarList}" } def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*) { @@ -779,6 +796,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private var listeningPort: Int = _ protected def serverPort: Int = listeningPort + protected val hiveConfList = "a=avalue;b=bvalue" + protected val hiveVarList = "c=cvalue;d=dvalue" protected def user = System.getProperty("user.name") protected var warehousePath: File = _ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c7717d70c996f..d9627eb9790eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -460,6 +460,7 @@ private[spark] object HiveUtils extends Logging { case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString HiveDecimal.create(decimal).toString + case (other, _ : UserDefinedType[_]) => other.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index d9cf1f361c1d6..61f9179042fe4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -108,10 +108,7 @@ class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton { } test("window specification") { - val frame = SpecifiedWindowFrame.defaultWindowFrame( - hasOrderSpecification = true, - acceptWindowFrame = true - ) + val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index 8697d47e89e89..f2b75e4b23f02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SQLTestUtils} import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -62,4 +62,10 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton Thread.currentThread().setContextClassLoader(contextClassLoader) } } + + test("toHiveString correctly handles UDTs") { + val point = new ExamplePoint(50.0, 50.0) + val tpe = new ExamplePointUDT() + assert(HiveUtils.toHiveString((point, tpe)) === "(50.0, 50.0)") + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ed2a896033749..aed67a5027433 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,21 +53,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.bindAddress", "spark.driver.port", - "spark.kubernetes.driver.pod.name", - "spark.kubernetes.executor.podNamePrefix", - "spark.kubernetes.initcontainer.executor.configmapname", - "spark.kubernetes.initcontainer.executor.configmapkey", - "spark.kubernetes.initcontainer.downloadJarsResourceIdentifier", - "spark.kubernetes.initcontainer.downloadJarsSecretLocation", - "spark.kubernetes.initcontainer.downloadFilesResourceIdentifier", - "spark.kubernetes.initcontainer.downloadFilesSecretLocation", - "spark.kubernetes.initcontainer.remoteJars", - "spark.kubernetes.initcontainer.remoteFiles", - "spark.kubernetes.mountdependencies.jarsDownloadDir", - "spark.kubernetes.mountdependencies.filesDownloadDir", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.name", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir", - "spark.kubernetes.executor.limit.cores", "spark.master", "spark.yarn.jars", "spark.yarn.keytab", @@ -81,7 +66,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.bindAddress") - .remove("spark.kubernetes.driver.pod.name") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 27644a645727c..5d38c56aa5873 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -159,7 +159,9 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) - trackerEndpoint.askSync[Boolean](AddBlock(blockInfo)) + if (!trackerEndpoint.askSync[Boolean](AddBlock(blockInfo))) { + throw new SparkException("Failed to add block to receiver tracker.") + } logDebug(s"Reported block $blockId") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6f130c803f310..c74ca1918a81d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -521,7 +521,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (active) { context.reply(addBlock(receivedBlockInfo)) } else { - throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + context.sendFailure( + new IllegalStateException("ReceiverTracker RpcEndpoint already shut down.")) } } })