[WIP][Experiment] Get rid of redundant key part from value for stateful aggregation

* add an option to configure enabling the new feature
* modify code to respect the new option (turning the feature on/off)
* modify tests so they run with the option both on and off
* add a guard in OffsetSeqMetadata to prevent modifying the option after a query has been executed
HeartSaVioR committed Jul 8, 2018
1 parent 79c6689 commit 378ce2a
Showing 4 changed files with 516 additions and 343 deletions.
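The gist of the change, as a minimal standalone sketch (plain Scala, illustrative only; the row type and column names are hypothetical, not the commit's code): state rows for streaming aggregation are keyed, and the stored value previously duplicated every key column. With the new option on, only non-key columns are persisted; the read path rebuilds the full row by re-joining key and value, projecting back to the original column order when the key columns are not a prefix of the row.

```scala
case class Field(name: String, value: Any)

val fullRow = Seq(Field("userId", 42), Field("count", 7L), Field("sum", 99.0))
val keyCols = Set("userId")

// With the option enabled, only non-key columns are written to the state store.
val storedValue = fullRow.filterNot(f => keyCols.contains(f.name))

// Restore path: key ++ value, then project back to the original column order
// (needed whenever the key columns are not a prefix of the row).
val key = fullRow.filter(f => keyCols.contains(f.name))
val restored = (key ++ storedValue).sortBy(f => fullRow.indexWhere(_.name == f.name))

assert(restored == fullRow)
```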
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -825,6 +825,16 @@ object SQLConf {
.intConf
.createWithDefault(100)

+ val ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION =
+ buildConf("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation")
+ .internal()
+ // FIXME: this doc must be written properly later
+ .doc("ADVANCED option: When true, stateful aggregation tries to remove redundant data " +
+ "between key and value in state. Enabling this option helps minimize state size, " +
+ "but requires a couple of potentially expensive operations.")
+ .booleanConf
+ .createWithDefault(false)
+
val UNSUPPORTED_OPERATION_CHECK_ENABLED =
buildConf("spark.sql.streaming.unsupportedOperationCheck")
.internal()
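For context, a hypothetical usage sketch: only the conf key comes from this diff; the rest is stock Structured Streaming, and the checkpoint path is made up. Since the commit records the option in OffsetSeqMetadata and guards it afterwards, it has to be set before the query first starts.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("remove-redundant-demo")
  // Must be set before the streaming query ever runs; pinned in the offset log afterwards.
  .config("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation", "true")
  .getOrCreate()
import spark.implicits._

// Any stateful aggregation exercises StateStoreRestoreExec / StateStoreSaveExec below.
val counts = spark.readStream.format("rate").load()
  .groupBy(($"value" % 10).as("bucket"))
  .count()

counts.writeStream
  .outputMode("update")
  .format("console")
  .option("checkpointLocation", "/tmp/remove-redundant-demo") // hypothetical path
  .start()
```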
@@ -1548,6 +1558,9 @@ class SQLConf extends Serializable with Logging {
def advancedPartitionPredicatePushdownEnabled: Boolean =
getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN)

+ def advancedRemoveRedundantInStatefulAggregation: Boolean =
+ getConf(ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION)
+
def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS)

def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
- import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS}
+ import org.apache.spark.sql.internal.SQLConf.{ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION,
+ SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS}

/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
@@ -86,7 +87,8 @@ case class OffsetSeqMetadata(

object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
- private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS)
+ private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS,
+ ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
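For reference, roughly what an offset log metadata entry would look like once the new conf is tracked; the timestamps and values below are made up, and OffsetSeqMetadata is internal API, shown only to illustrate the round-trip.

```scala
import org.apache.spark.sql.execution.streaming.OffsetSeqMetadata

val json =
  """{"batchWatermarkMs":0,"batchTimestampMs":1531008000000,"conf":{
    |"spark.sql.shuffle.partitions":"200",
    |"spark.sql.streaming.stateStore.providerClass":
    |"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider",
    |"spark.sql.streaming.advanced.removeRedundantInStatefulAggregation":"true"}}""".stripMargin

val meta = OffsetSeqMetadata(json) // parses via json4s Serialization, as in apply() above
```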

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
- import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
+ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
@@ -204,30 +204,64 @@ case class StateStoreRestoreExec(
child: SparkPlan)
extends UnaryExecNode with StateStoreReader {

+ val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
+ if (removeRedundant) {
+ log.info("Advanced option removeRedundantInStatefulAggregation activated!")
+ }
+
+ val valueExpressions: Seq[Attribute] = if (removeRedundant) {
+ child.output.diff(keyExpressions)
+ } else {
+ child.output
+ }
+ val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
+ val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
+
override protected def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

child.execute().mapPartitionsWithStateStore(
getStateInfo,
keyExpressions.toStructType,
- child.output.toStructType,
+ valueExpressions.toStructType,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+ val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
+ StructType.fromAttributes(valueExpressions))
+ val restoreValueProject = GenerateUnsafeProjection.generate(
+ keyValueJoinedExpressions, child.output)
+
val hasInput = iter.hasNext
if (!hasInput && keyExpressions.isEmpty) {
// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
// the `HashAggregateExec` will output a 0 value for the partial merge. We need to
// restore the value, so that we don't overwrite our state with a 0 value, but rather
// merge the 0 with existing state.
+ // Here `keyExpressions` is empty, so no key columns were removed from the stored value;
+ // it already represents the full row and doesn't need to be restored.
store.iterator().map(_.value)
} else {
iter.flatMap { row =>
val key = getKey(row)
val savedState = store.get(key)
+ val restoredRow = if (removeRedundant) {
+ if (savedState == null) {
+ savedState
+ } else {
+ val joinedRow = joiner.join(key, savedState)
+ if (needToProjectToRestoreValue) {
+ restoreValueProject(joinedRow)
+ } else {
+ joinedRow
+ }
+ }
+ } else {
+ savedState
+ }
+
numOutputRows += 1
- Option(savedState).toSeq :+ row
+ Option(restoredRow).toSeq :+ row
}
}
}
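The restore path above relies on GenerateUnsafeRowJoiner to concatenate the stored key and the trimmed value back into a single row. A minimal standalone sketch of that internal API (the schemas and values here are hypothetical):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

val keySchema = StructType(Seq(StructField("word", StringType)))
val valueSchema = StructType(Seq(StructField("count", LongType)))

val key = UnsafeProjection.create(keySchema)(InternalRow(UTF8String.fromString("spark")))
val value = UnsafeProjection.create(valueSchema)(InternalRow(10L))

// join() concatenates the two UnsafeRows; the result's schema is key fields then value fields.
val joiner = GenerateUnsafeRowJoiner.create(keySchema, valueSchema)
val full = joiner.join(key, value)
assert(full.numFields == 2)
```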
@@ -257,6 +291,19 @@ case class StateStoreSaveExec(
child: SparkPlan)
extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

+ val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
+ if (removeRedundant) {
+ log.info("Advanced option removeRedundantInStatefulAggregation activated!")
+ }
+
+ val valueExpressions: Seq[Attribute] = if (removeRedundant) {
+ child.output.diff(keyExpressions)
+ } else {
+ child.output
+ }
+ val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
+ val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
+
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
assert(outputMode.nonEmpty,
@@ -265,11 +312,17 @@
child.execute().mapPartitionsWithStateStore(
getStateInfo,
keyExpressions.toStructType,
- child.output.toStructType,
+ valueExpressions.toStructType,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+ val getValue = GenerateUnsafeProjection.generate(valueExpressions, child.output)
+ val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
+ StructType.fromAttributes(valueExpressions))
+ val restoreValueProject = GenerateUnsafeProjection.generate(
+ keyValueJoinedExpressions, child.output)
+
val numOutputRows = longMetric("numOutputRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
@@ -283,7 +336,13 @@
while (iter.hasNext) {
val row = iter.next().asInstanceOf[UnsafeRow]
val key = getKey(row)
- store.put(key, row)
+ val value = if (removeRedundant) {
+ // TODO: could we do better to avoid the copying overhead inside the store?
+ getValue(row)
+ } else {
+ row
+ }
+ store.put(key, value)
numUpdatedStateRows += 1
}
}
@@ -294,7 +353,18 @@
setStoreMetrics(store)
store.iterator().map { rowPair =>
numOutputRows += 1
- rowPair.value
+
+ if (removeRedundant) {
+ val joinedRow = joiner.join(rowPair.key, rowPair.value)
+ if (needToProjectToRestoreValue) {
+ restoreValueProject(joinedRow)
+ } else {
+ joinedRow
+ }
+ } else {
+ rowPair.value
+ }
+
}

// Update and output only rows being evicted from the StateStore
@@ -305,7 +375,13 @@
while (filteredIter.hasNext) {
val row = filteredIter.next().asInstanceOf[UnsafeRow]
val key = getKey(row)
- store.put(key, row)
+ val value = if (removeRedundant) {
+ // TODO: could we do better to avoid the copying overhead inside the store?
+ getValue(row)
+ } else {
+ row
+ }
+ store.put(key, value)
numUpdatedStateRows += 1
}
}
@@ -320,7 +396,17 @@
val rowPair = rangeIter.next()
if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
store.remove(rowPair.key)
- removedValueRow = rowPair.value
+
+ if (removeRedundant) {
+ val joinedRow = joiner.join(rowPair.key, rowPair.value)
+ removedValueRow = if (needToProjectToRestoreValue) {
+ restoreValueProject(joinedRow)
+ } else {
+ joinedRow
+ }
+ } else {
+ removedValueRow = rowPair.value
+ }
}
}
if (removedValueRow == null) {
@@ -353,7 +439,13 @@
if (baseIterator.hasNext) {
val row = baseIterator.next().asInstanceOf[UnsafeRow]
val key = getKey(row)
- store.put(key, row)
+ val value = if (removeRedundant) {
+ // TODO: could we do better to avoid the copying overhead inside the store?
+ getValue(row)
+ } else {
+ row
+ }
+ store.put(key, value)
numOutputRows += 1
numUpdatedStateRows += 1
row
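Back-of-envelope on the trade-off (illustrative arithmetic, not a measurement):

```scala
// With a single LongType grouping key, each state row previously stored those
// 8 bytes twice, once in the key and once in the value (plus per-field overhead).
val rows = 10L * 1000 * 1000   // hypothetical number of state rows
val savedBytes = 8L * rows     // 80,000,000 bytes, roughly 76 MiB of state
// The cost is one UnsafeRow join (and possibly a projection) per state read/write:
// the "couple of potentially expensive operations" the conf doc warns about.
```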
