[Draft][CORE][SQL] Add debugging operator to identify skew in datasets inline #46490

Draft · wants to merge 22 commits into base: master
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -352,7 +353,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Batch("HandleSpecialCommand", Once,
HandleSpecialCommand),
Batch("Remove watermark for batch query", Once,
EliminateEventTimeWatermark)
EliminateEventTimeWatermark),
Batch("Debug", Once, DebugInlineColumnsCountInference)
)

/**
@@ -4126,3 +4128,29 @@ object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
}
}
}

/**
 * Infers the columns to use for [[DebugInlineColumnsCount]] when possible.
 * For joins, it uses the join key columns so the application code does not need
 * to specify them for both the inputs and the output.
 */
object DebugInlineColumnsCountInference extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transformUp {
      case DebugInlineColumnsCount(j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _,
          _, _, _), sampleColumns, maxKeys) if sampleColumns.isEmpty =>
        // Wrap both join inputs so the key distribution is observed before and after the join.
        val left = wrapJoinInput(j.left, leftKeys, maxKeys)
        val right = wrapJoinInput(j.right, rightKeys, maxKeys)

        val joinWithInputDebug = j.withNewChildren(Seq(left, right))
        DebugInlineColumnsCount(joinWithInputDebug, leftKeys, maxKeys)
    }
  }

  private def wrapJoinInput(
      plan: LogicalPlan,
      sampleKeys: Seq[Expression],
      maxKeys: Int): LogicalPlan = plan match {
    case d: DebugInlineColumnsCount => d
    case _ => DebugInlineColumnsCount(plan, sampleKeys, maxKeys)
  }
}
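To illustrate the intent of this rule (a hedged sketch; the table and column names are invented, and `spark` is assumed to be an active SparkSession): calling the new `inlineColumnsCount` helper from the `debug` package object (added later in this diff) on an equi-join without naming any columns leaves `sampleColumns` empty, and this rule then fills in the join keys for both inputs and the join output.

```scala
import org.apache.spark.sql.execution.debug._

// Hypothetical tables; both are assumed to contain a user_id column.
val orders = spark.table("orders")
val users = spark.table("users")
val joined = orders.join(users, "user_id")

// No columns are specified, so sampleColumns is empty and the rule above rewrites the
// plan to count user_id occurrences on each join input as well as on the join output.
joined.inlineColumnsCount().collect()
```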
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}


case class DebugInlineColumnsCount(
    child: LogicalPlan,
    sampleColumns: Seq[Expression],
    maxKeys: Int
) extends UnaryNode {

  override protected def withNewChildInternal(newChild: LogicalPlan): DebugInlineColumnsCount =
    copy(child = newChild)

  override def output: Seq[Attribute] = child.output
}
@@ -3969,6 +3969,12 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

  val MAX_INLINE_COLUMN_COUNT_KEYS = buildConf("spark.sql.debug.maxInlineColumnCountKeys")
    .doc("Maximum number of keys to maintain when counting the frequency of each key.")
    .version("4.0.0")
    .intConf
    .createWithDefault(3)

val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields")
.doc("Maximum number of fields of sequence-like entries can be converted to strings " +
"in debug output. Any elements beyond the limit will be dropped and replaced by a" +
@@ -5721,6 +5727,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def nameNonStructGroupingKeyAsValue: Boolean =
getConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE)

def maxInlineColumnCountKeys: Int = getConf(SQLConf.MAX_INLINE_COLUMN_COUNT_KEYS)

override def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS)

def maxPlanStringLength: Int = getConf(SQLConf.MAX_PLAN_STRING_LENGTH).toInt
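As a quick usage sketch for this new conf (the value 100 is only an illustrative choice, and `spark` is assumed to be an active SparkSession), the limit on tracked keys can be raised per session before running the debug operator:

```scala
// Keep counts for up to 100 distinct key combinations instead of the default 3.
spark.conf.set("spark.sql.debug.maxInlineColumnCountKeys", "100")
```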
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.adaptive.LogicalQueryStageStrategy
import org.apache.spark.sql.execution.command.v2.V2CommandStrategy
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.execution.debug.DebugPlanner

class SparkPlanner(val session: SparkSession, val experimentalMethods: ExperimentalMethods)
extends SparkStrategies with SQLConfHelper {
@@ -47,7 +48,8 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen
JoinSelection ::
InMemoryScans ::
SparkScripts ::
BasicOperators :: Nil)
BasicOperators ::
DebugPlanner :: Nil)

/**
* Override to add extra planning strategies to the planner. These strategies are tried after
@@ -17,8 +17,10 @@

package org.apache.spark.sql.execution

import java.io.StringWriter
import java.util.Collections

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

@@ -27,17 +29,22 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeFormatter, CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.plans.logical.{DebugInlineColumnsCount, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.StringConcat
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, StringConcat}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, VariantType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}
import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.{AccumulatorV2, LongAccumulator, Utils}

/**
* Contains methods for debugging query execution.
@@ -198,6 +205,19 @@ package object debug {
def debugCodegen(): Unit = {
debugPrint(codegenString(query.queryExecution.executedPlan))
}

    /**
     * Counts the occurrences of values for the specified column combination and periodically
     * prints the results to stdout. The results will not be perfectly accurate because only
     * the top K values are maintained. This is useful for identifying which values are
     * creating skew in a column.
     * @param columns The combination of columns to count the value occurrences for
     */
    def inlineColumnsCount(columns: Column*): Dataset[_] = {
      val maxKeys = SQLConf.get.maxInlineColumnCountKeys
      val plan = DebugInlineColumnsCount(query.logicalPlan, columns.map(_.expr), maxKeys)
      Dataset.ofRows(query.sparkSession, plan)
    }
}

implicit class DebugStreamQuery(query: StreamingQuery) extends Logging {
@@ -206,7 +226,6 @@ package object debug {
}
}


class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] {
private val _set = Collections.synchronizedSet(new java.util.HashSet[T]())

@@ -295,4 +314,145 @@ package object debug {
override protected def withNewChildInternal(newChild: SparkPlan): DebugExec =
copy(child = newChild)
}

  case class DebugInlineColumnsCountExec(
      child: SparkPlan,
      sampleColumns: Seq[Expression],
      maxKeys: Int
  ) extends UnaryExecNode {

    private val jsonOptions = new JSONOptions(Map.empty[String, String], "UTC")

    // Named accumulator that collects the approximate top values; it is registered up front so
    // it is reported with a descriptive name for this operator and column combination.
    private val accumulator = new DebugAccumulator(maxKeys)
    accumulator.register(
      session.sparkContext,
      Some(s"${child.nodeName} top values for ${sampleColumns.mkString(",")}"))

    override protected def withNewChildInternal(newChild: SparkPlan): DebugInlineColumnsCountExec =
      copy(child = newChild)

    override protected def doExecute(): RDD[InternalRow] = {
      val exprs = bindReferences[Expression](sampleColumns, child.output)

      // Pass every row through unchanged while recording the sampled column values.
      child.execute().mapPartitions { iter =>
        iter.map { row =>
          val sampleVals = exprs.map { expr => valToString(expr.dataType, expr.eval(row)) }
          accumulator.add(sampleVals.mkString(","))
          row
        }
      }
    }

    // Renders a value as a string, using JSON for complex types so nested values stay readable.
    private def valToString(dataType: DataType, value: Any): String = {
      Option(value).map { v =>
        dataType match {
          case _: StructType | _: ArrayType | _: MapType | _: VariantType =>
            Utils.tryWithResource(new StringWriter()) { writer =>
              val gen = new JacksonGenerator(dataType, writer, jsonOptions)

              dataType match {
                case _: StructType =>
                  gen.write(v.asInstanceOf[InternalRow])
                case _: ArrayType =>
                  gen.write(v.asInstanceOf[ArrayData])
                case _: MapType =>
                  gen.write(v.asInstanceOf[MapData])
                case _: VariantType =>
                  gen.write(v.asInstanceOf[VariantVal])
              }

              gen.flush()
              writer.toString
            }
          case _ => v.toString
        }
      }.orNull
    }

    override def output: Seq[Attribute] = child.output
  }

  object DebugPlanner extends SparkStrategy {
    override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
      plan match {
        case DebugInlineColumnsCount(child, sampleColumns, maxKeys) =>
          DebugInlineColumnsCountExec(planLater(child), sampleColumns, maxKeys) :: Nil
        case _ => Nil
      }
    }
  }

  class DebugAccumulator(maxKeys: Int) extends AccumulatorV2[String, Map[String, Long]] {
    private val keyToCount = mutable.Map.empty[String, Long]
    private val countToKeys = mutable.TreeMap.empty[Long, mutable.Set[String]]

    /**
     * Returns if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero
     * value; for a list accumulator, Nil is zero value.
     */
    override def isZero: Boolean = this.synchronized { keyToCount.isEmpty }

    /**
     * Creates a new copy of this accumulator.
     */
    override def copy(): DebugAccumulator = {
      val newAcc = new DebugAccumulator(maxKeys)
      newAcc.merge(this)
      newAcc
    }

    /**
     * Resets this accumulator, which is zero value. i.e. call `isZero` must
     * return true.
     */
    override def reset(): Unit = this.synchronized {
      keyToCount.clear()
      countToKeys.clear()
    }

    /**
     * Takes the inputs and accumulates.
     */
    override def add(v: String): Unit = add(v, 1L)

    private def add(v: String, delta: Long): Unit = this.synchronized {
      val oldCount = keyToCount.getOrElse(v, 0L)
      val count = oldCount + delta
      keyToCount.put(v, count)

      // Move the key out of its previous count bucket so countToKeys stays in sync with
      // keyToCount; otherwise stale entries could cause the wrong key to be evicted.
      if (oldCount > 0L) {
        countToKeys.get(oldCount).foreach { oldKeys =>
          oldKeys.remove(v)
          if (oldKeys.isEmpty) {
            countToKeys.remove(oldCount)
          }
        }
      }
      countToKeys.getOrElseUpdate(count, mutable.Set[String]()).add(v)

      if (keyToCount.size > maxKeys) {
        dropSmallest()
      }
    }

    // Evicts one key from the bucket with the smallest count to bound the number of tracked keys.
    private def dropSmallest(): Unit = {
      val (smallestCount, keys) = countToKeys.head
      val keyToDrop = keys.head

      keys.remove(keyToDrop)
      keyToCount.remove(keyToDrop)

      if (keys.isEmpty) {
        countToKeys.remove(smallestCount)
      }
    }

    /**
     * Merges another same-type accumulator into this one and update its state, i.e. this should be
     * merge-in-place.
     */
    override def merge(other: AccumulatorV2[String, Map[String, Long]]): Unit = this.synchronized {
      other match {
        case o: DebugAccumulator => o.keyToCount.foreach { case (k, v) => add(k, v) }
        case _ =>
          throw new UnsupportedOperationException(s"Cannot merge with ${other.getClass.getName}")
      }
    }

    /**
     * Defines the current value of this accumulator
     */
    override def value: Map[String, Long] = this.synchronized {
      keyToCount.toMap
    }
  }
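Putting the pieces together, here is a hedged end-to-end sketch (the dataset and column names are invented, and `spark` is assumed to be an active SparkSession): `inlineColumnsCount` wraps the query in a `DebugInlineColumnsCount` node, `DebugPlanner` plans it as `DebugInlineColumnsCountExec`, and the registered `DebugAccumulator` keeps an approximate top-K count of the sampled values.

```scala
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions.col

// Illustrative skewed dataset: most rows share user_id = 42.
val events = spark.range(0, 1000000L)
  .selectExpr("id", "if(id % 100 = 0, id, 42) AS user_id")

// Count occurrences of user_id values inline while the query runs; only the top
// spark.sql.debug.maxInlineColumnCountKeys values are retained.
val debugged = events.inlineColumnsCount(col("user_id"))
debugged.count()
```

Since the accumulator is registered with a name derived from the child node and the sampled columns, its approximate value counts also show up as a named accumulator for the stage in the Spark UI while the job runs.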
}