@@ -640,6 +640,16 @@ object SQLConf {
.intConf
.createWithDefault(200)

val THRIFTSERVER_INCREMENTAL_DESERIALIZE =
buildConf("spark.sql.thriftServer.incrementalDeserialize")
.doc("When true, Thrift Server will deserialize result rows incrementally." +
"This feature only has an effect on collection phase, and does not affect scheduling of " +
"partitions. By deserialzing incrementally, the driver of Thrift Server will only use " +
"'memory of serialized rows' + 'memory of the deserialized rows being fetched to the " +
s"client'. Only valid if ${THRIFTSERVER_INCREMENTAL_COLLECT.key} is false.")
.booleanConf
.createWithDefault(false)

// This is used to set the default data source
val DEFAULT_DATA_SOURCE_NAME = buildConf("spark.sql.sources.default")
.doc("The default data source to use in input/output.")
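A hedged usage sketch, not part of the patch: a Thrift Server client would toggle the new flag per session with a plain SET statement, the same way the test at the bottom of this PR does. The JDBC URL, credentials, and table name below are placeholders, and the flag only takes effect while spark.sql.thriftServer.incrementalCollect stays at its default of false.

import java.sql.DriverManager

// Placeholder connection details for a running Spark Thrift Server (Hive JDBC driver assumed).
val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "user", "")
val stmt = conn.createStatement()
// Only effective while spark.sql.thriftServer.incrementalCollect remains false.
stmt.execute("SET spark.sql.thriftServer.incrementalDeserialize=true")
val rs = stmt.executeQuery("SELECT * FROM some_table")
while (rs.next()) {
  // With the flag on, the driver deserializes rows only as the client fetches them.
  println(rs.getString(1))
}
rs.close()
stmt.close()
conn.close()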
5 changes: 2 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -218,7 +218,7 @@ class Dataset[T] private[sql](

// The deserializer expression which can be used to build a projection and turn rows to objects
// of type T, after collecting rows to the driver side.
private lazy val deserializer =
private[sql] lazy val deserializer =
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer

private implicit def classTag = exprEnc.clsTag
@@ -3162,7 +3162,6 @@ class Dataset[T] private[sql](
}.flatten
files.toSet.toArray
}

////////////////////////////////////////////////////////////////////////////
// For Python API
////////////////////////////////////////////////////////////////////////////
@@ -3342,7 +3341,7 @@ class Dataset[T] private[sql](
* Wrap a Dataset action to track the QueryExecution and time cost, then report to the
* user-registered callback functions.
*/
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
private[sql] def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
qe.executedPlan.foreach { plan =>
plan.resetMetrics()
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import scala.collection.SeqView
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext

@@ -309,6 +310,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
(total, rows)
}

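/**
 * Runs this query and returns the result as a SeqView over the collected, still-serialized
 * partition byte arrays; rows are decoded lazily as the view is traversed.
 */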
private[spark] def executeCollectSeqView(): SeqView[InternalRow, Array[InternalRow]] = {
val countsAndBytes = getByteArrayRdd().collect()
countsAndBytes.view.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
}

/**
* Runs this query returning the result as an iterator of InternalRow.
*
@@ -377,6 +383,58 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

/**
* Runs this query returning a SeqView of the first `n` rows.
*
* The returned view invokes decodeUnsafeRows lazily, reducing the peak memory needed to decode
* rows. Only the compressed byte arrays consume memory after this method returns.
*/
private[spark] def executeTakeSeqView(n: Int): SeqView[InternalRow, Array[InternalRow]] = {
if (n == 0) {
return Array.empty[InternalRow].view
}

val childRDD = getByteArrayRdd(n)
val encodedBuf = new ArrayBuffer[Array[Byte]]
val totalParts = childRDD.partitions.length
var scannedRowCount = 0L
var partsScanned = 0
while (scannedRowCount < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
if (scannedRowCount == 0) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
val left = n - scannedRowCount
// As left > 0, numPartsToTry is always >= 1
numPartsToTry = Math.ceil(1.5 * left * partsScanned / scannedRowCount).toInt
numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
}
}

val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val sc = sqlContext.sparkContext
val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
if (it.hasNext) it.next() else (0L, Array.empty[Byte]), p)

encodedBuf ++= res.map(_._2)
scannedRowCount += res.map(_._1).sum
partsScanned += p.size
}

if (scannedRowCount > n) {
encodedBuf.toArray.view.flatMap(decodeUnsafeRows).take(n)
} else {
encodedBuf.toArray.view.flatMap(decodeUnsafeRows)
}
}
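
A worked instance of the estimation above, with made-up numbers rather than anything from the patch: after 2 scanned partitions have yielded 20 of the 100 requested rows and spark.sql.limit.scaleUpFactor is 4, the interpolation suggests 12 more partitions and the cap brings that down to 8.

// Hypothetical values: 2 partitions scanned so far yielded 20 of the 100 requested rows.
val n = 100
val scannedRowCount = 20L
val partsScanned = 2
val limitScaleUpFactor = 4
// Interpolate: ceil(1.5 * 80 * 2 / 20) = 12 partitions ...
val interpolated = Math.ceil(1.5 * (n - scannedRowCount) * partsScanned / scannedRowCount).toInt
// ... then cap at partsScanned * limitScaleUpFactor = 8, so the next pass tries 8 partitions.
val numPartsToTry = Math.min(interpolated, partsScanned * limitScaleUpFactor)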

protected def newMutableProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute],
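Both new SparkPlan methods rely on the laziness of Scala 2.12's SeqView: flatMap over a view records the transformation but decodes nothing until the view is iterated, which is what keeps only the serialized byte arrays resident on the driver. A minimal, self-contained sketch of that behavior, where decode is a stand-in for decodeUnsafeRows and all values are illustrative:

// Stand-in for the serialized partitions the driver collects.
val encodedPartitions: Array[Array[Byte]] = Array(Array[Byte](1, 2), Array[Byte](3))

// Stand-in for decodeUnsafeRows: the println makes the laziness observable.
def decode(bytes: Array[Byte]): Iterator[Int] = {
  println(s"decoding ${bytes.length} bytes")
  bytes.iterator.map(_.toInt)
}

// Nothing is decoded yet; the view only records the pending flatMap
// (under Scala 2.12 its type mirrors the patch's SeqView[Int, Array[Int]]).
val rows = encodedPartitions.view.flatMap(decode)

// Decoding runs here, as the iterator is consumed; for this prefix only the
// first partition needs to be decoded.
rows.iterator.take(2).foreach(println)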
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.command

import java.util.UUID

import scala.collection.SeqView

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -78,6 +80,9 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {

override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray

override private[spark] def executeCollectSeqView(): SeqView[InternalRow, Array[InternalRow]] =
executeCollect().view

override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator

override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray
@@ -115,6 +120,9 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)

override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray

override private[spark] def executeCollectSeqView(): SeqView[InternalRow, Array[InternalRow]] =
executeCollect().view

override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator

override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution

import scala.collection.SeqView

import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
@@ -37,6 +39,8 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition
override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
override private[spark] def executeCollectSeqView(): SeqView[InternalRow, Array[InternalRow]] =
child.executeTakeSeqView(limit)
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -23,6 +23,7 @@ import java.util.{Arrays, Map => JMap, UUID}
import java.util.concurrent.RejectedExecutionException

import scala.collection.JavaConverters._
import scala.collection.SeqView
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

@@ -35,6 +36,7 @@ import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLContext}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.execution.command.SetCommand
import org.apache.spark.sql.internal.SQLConf
@@ -55,7 +57,7 @@ private[hive] class SparkExecuteStatementOperation(
// We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST.
// This is only used when `spark.sql.thriftServer.incrementalCollect` is set to `false`.
// In case of `true`, this will be `None` and FETCH_FIRST will trigger re-execution.
private var resultList: Option[Array[SparkRow]] = _
private var resultList: Option[Seq[SparkRow]] = _

private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _
@@ -122,7 +124,12 @@ private[hive] class SparkExecuteStatementOperation(
result.toLocalIterator.asScala
} else {
if (resultList.isEmpty) {
resultList = Some(result.collect())
Member: Can you keep the current behavior? Then, please implement a SeqView iteration model turned on/off by a new option.

Author: Yea, I'll change this feature as boolean. Thank you for review.

Author: Done

resultList = if (sqlContext.getConf(
SQLConf.THRIFTSERVER_INCREMENTAL_DESERIALIZE.key).toBoolean) {
Some(result.collectAsSeqView())
} else {
Some(result.collect())
}
}
resultList.get.iterator
}
@@ -245,7 +252,12 @@ private[hive] class SparkExecuteStatementOperation(
resultList = None
result.toLocalIterator.asScala
} else {
resultList = Some(result.collect())
resultList = if (sqlContext.getConf(
SQLConf.THRIFTSERVER_INCREMENTAL_DESERIALIZE.key).toBoolean) {
Some(result.collectAsSeqView())
} else {
Some(result.collect())
}
resultList.get.iterator
}
}
@@ -291,6 +303,23 @@ private[hive] class SparkExecuteStatementOperation(
sqlContext.sparkContext.cancelJobGroup(statementId)
}
}

private implicit class DataFrameWrapper(df: DataFrame) {
/**
* Returns a SeqView that contains all rows in this Dataset.
*
* The SeqView consumes as much memory as the total size of the serialized results, which can be
* limited with the config 'spark.driver.maxResultSize'. Rows are deserialized lazily while
* iterating over the returned SeqView.
*/
def collectAsSeqView(): SeqView[SparkRow, Array[SparkRow]] =
df.withAction("collectAsSeqView", df.queryExecution) { plan =>
val objProj = GenerateSafeProjection.generate(df.deserializer :: Nil)
plan.executeCollectSeqView().map(row =>
objProj(row).get(0, null).asInstanceOf[SparkRow]
).asInstanceOf[SeqView[SparkRow, Array[SparkRow]]]
}
}
}

object SparkExecuteStatementOperation {
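A note on why the cached result can still serve FETCH_FIRST: under Scala 2.12 a SeqView is itself a Seq, so storing it in resultList (an Option[Seq[SparkRow]]) lets every fetch ask for a fresh iterator without re-running the Spark job, at the cost of re-deserializing rows on each pass. An illustrative sketch with made-up values:

// A view that "deserializes" lazily; the println marks where the work happens.
val cached: Seq[Int] = (1 to 3).view.map { i =>
  println(s"deserializing row $i")
  i
}

// FETCH_NEXT: the client walks a prefix of the cached result.
val firstTwo = cached.iterator.take(2).toList

// FETCH_FIRST: a fresh iterator restarts from the first row. No Spark job is re-run,
// but the rows are deserialized again, trading CPU for driver memory.
val all = cached.iterator.toList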
@@ -644,6 +644,66 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
assert(resultSet.getString(1) === "4.56")
}
}

test("SPARK-25224 Checks config of incrementalDeserialize.") {
import org.apache.spark.sql.internal.SQLConf
withCLIServiceClient { client =>
val user = System.getProperty("user.name")
val sessionHandle = client.openSession(user, "")
withJdbcStatement("test_25224") { statement =>
statement.execute("CREATE TABLE test_25224(key INT, val STRING)")

val initQueries = Seq(s"SET ${SQLConf.THRIFTSERVER_INCREMENTAL_DESERIALIZE.key}=true",
s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_25224")
val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String]
initQueries.foreach { query =>
client.executeStatement(sessionHandle, query, confOverlay)
}

val operationHandle = client.executeStatement(
sessionHandle,
"SELECT * FROM test_25224 WHERE key IS NULL",
confOverlay)

assertResult(5, "Fetching result first time with fetch_next with row by row") {
var totalRowCount = 0
var fetchMore = true
while (fetchMore) {
// FETCH_NEXT with maxRows = 1
val rowCount = client.fetchResults(
operationHandle,
FetchOrientation.FETCH_NEXT,
1,
FetchType.QUERY_OUTPUT).numRows()
if (rowCount <= 0) fetchMore = false
else totalRowCount += rowCount
}
totalRowCount
}

assertResult(5, "Repeat fetching result from fetch_first") {
// Repeating FETCH_FIRST with maxRows = 1000
val rows_first = client.fetchResults(
operationHandle,
FetchOrientation.FETCH_FIRST,
1000,
FetchType.QUERY_OUTPUT)

rows_first.numRows()
}

statement.execute(s"SET ${SQLConf.THRIFTSERVER_INCREMENTAL_DESERIALIZE.key}=true")
statement.setFetchSize(2)
val rs = statement.executeQuery("SELECT * FROM test_25224 WHERE key IS NOT NULL LIMIT 3")
Seq(238, 311, 255).foreach { key =>
rs.next()
assert(rs.getInt(1) === key)
}
assert(!rs.next())
rs.close()
}
}
}
}

class SingleSessionSuite extends HiveThriftJdbcTest {