[GLUTEN-5696] Add preprojection support for ArrowEvalPythonExec
yma11 committed May 12, 2024
1 parent 4899ea5 commit b40e1d5
Showing 7 changed files with 151 additions and 9 deletions.
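
For context, the query shape this change targets (a hedged sketch mirroring the new test cases below; it assumes the suite's pyarrowTestUDF helper, spark.implicits._, and a Gluten-enabled session): when a Python UDF is applied to a derived expression rather than a bare column, that expression must first be materialized by a projection so that ColumnarArrowEvalPythonExec can pick its input columns by offset.

// Sketch only: pyarrowTestUDF is the Arrow UDF helper used in ArrowEvalPythonExecSuite.
val base = Seq(("1", 1), ("2", 2)).toDF("a", "b")
// base("b") * 2 is not an AttributeReference, so PullOutPreProject is expected to insert
// a ProjectExec that computes it before the columnar Arrow Python evaluation.
val df = base.withColumn("p_b", pyarrowTestUDF(base("b") * 2))
df.collect()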
@@ -860,4 +860,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genPreProjectForArrowEvalPythonExec(
arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = arrowEvalPythonExec
}
@@ -31,7 +31,7 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}

import org.apache.spark.{ShuffleDependency, SparkException}
import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction}
import org.apache.spark.sql.internal.SQLConf
@@ -835,6 +836,11 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
PullOutGenerateProjectHelper.pullOutPostProject(generate)
}

override def genPreProjectForArrowEvalPythonExec(
arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
PullOutArrowEvalPythonPreProjectHelper.pullOutPreProject(arrowEvalPythonExec)
}

override def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = {
// This to-top-n optimization assumes exchange operators were already placed in the input plan.
plan.transformUp {
@@ -17,17 +17,18 @@
package org.apache.spark.api.python

import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.utils.{Iterators, PullOutProjectHelper}
import org.apache.gluten.vectorized.ArrowWritableColumnVector

import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BasePythonRunnerShim, EvalPythonExec, PythonUDFRunner}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.utils.{SparkArrowUtil, SparkSchemaUtil, SparkVectorUtil}
@@ -41,6 +41,7 @@ import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.{mutable, Seq}
import scala.collection.mutable.ArrayBuffer

class ColumnarArrowPythonRunner(
@@ -207,7 +209,6 @@ case class ColumnarArrowEvalPythonExec(
extends EvalPythonExec
with GlutenPlan {
override def supportsColumnar: Boolean = true
// TODO: add additional projection support by pre-project
// FIXME: incorrect metrics updater

override protected def evaluate(
@@ -221,6 +222,7 @@
}

private val sessionLocalTimeZone = conf.sessionLocalTimeZone

private def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
val pandasColsByName = Seq(
@@ -231,6 +233,7 @@
conf.arrowSafeTypeConversion.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
}

private val pythonRunnerConf = getPythonRunnerConfMap(conf)

protected def evaluateColumnar(
@@ -279,16 +282,30 @@ case class ColumnarArrowEvalPythonExec(
iter =>
val context = TaskContext.get()
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
// Flatten all the arguments. Only the columns referenced by the UDFs are written to the
// Python worker, so record the offset of each referenced column in the child output.
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val originalOffsets = new ArrayBuffer[Int]
val argOffsets = inputs.map {
input =>
input.map {
e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
if (!e.isInstanceOf[AttributeReference]) {
throw new GlutenException(
"ColumnarArrowEvalPythonExec should only have [AttributeReference] inputs.")
}
val offset = child.output.indexWhere(
_.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
if (offset == -1) {
throw new GlutenException(
"ColumnarArrowEvalPythonExec cannot find the referenced input column.")
}
originalOffsets += offset
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
@@ -299,15 +316,20 @@
case (dt, i) =>
StructField(s"_$i", dt)
}.toSeq)

val contextAwareIterator = new ContextAwareIterator(context, iter)
val inputCbCache = new ArrayBuffer[ColumnarBatch]()
val inputBatchIter = contextAwareIterator.map {
inputCb =>
ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance, inputCb)
ColumnarBatches.retain(inputCb)
// 0. cache input for later merge
inputCbCache += inputCb
// Only pass the columns referenced by the UDFs to the Python worker.
val colsForEval = new ArrayBuffer[ColumnVector]()
for (i <- originalOffsets) {
colsForEval += inputCb.column(i)
}
new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
}

val outputColumnarBatchIterator =
@@ -335,6 +357,65 @@
.create()
}
}

override protected def withNewChildInternal(newChild: SparkPlan): ColumnarArrowEvalPythonExec =
copy(udfs, resultAttrs, newChild)
}

object PullOutArrowEvalPythonPreProjectHelper extends PullOutProjectHelper {
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

private def rewriteUDF(
udf: PythonUDF,
expressionMap: mutable.HashMap[Expression, NamedExpression]): PythonUDF = {
udf.children match {
case Seq(u: PythonUDF) =>
udf
.withNewChildren(udf.children.toIndexedSeq.map {
func => rewriteUDF(func.asInstanceOf[PythonUDF], expressionMap)
})
.asInstanceOf[PythonUDF]
case children =>
val newUDFChildren = udf.children.map {
case literal: Literal => literal
case other => replaceExpressionWithAttribute(other, expressionMap)
}
udf.withNewChildren(newUDFChildren).asInstanceOf[PythonUDF]
}
}

def pullOutPreProject(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
// Pull out a pre-project for UDF inputs that are not plain attribute references.
val (_, inputs) = arrowEvalPythonExec.udfs.map(collectFunctions).unzip
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
for (input <- inputs) {
input.foreach {
e =>
if (!allInputs.exists(_.semanticEquals(e))) {
allInputs += e
replaceExpressionWithAttribute(e, expressionMap)
}
}
}
if (expressionMap.nonEmpty) {
// A pre-project is needed: some UDF inputs are not plain attributes.
val preProject = ProjectExec(
eliminateProjectList(arrowEvalPythonExec.child.outputSet, expressionMap.values.toSeq),
arrowEvalPythonExec.child)
val newUDFs = arrowEvalPythonExec.udfs.map(f => rewriteUDF(f, expressionMap))
arrowEvalPythonExec.copy(udfs = newUDFs, child = preProject)
} else {
arrowEvalPythonExec
}
}
}
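
For readers unfamiliar with the pull-out pattern, the following is a minimal, self-contained sketch of the general idea only, not Gluten's PullOutProjectHelper (whose replaceExpressionWithAttribute and eliminateProjectList are defined elsewhere); the _pre_ naming is invented for illustration. Non-attribute UDF inputs are aliased into a child projection, and the consumer is rewritten to reference the resulting attributes.

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, NamedExpression}
import scala.collection.mutable

// Returns (expressions the child projection must add, rewritten UDF inputs).
def pullOut(inputs: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
  val aliased = mutable.LinkedHashMap.empty[Expression, NamedExpression]
  val rewritten = inputs.map {
    case a: AttributeReference => a // already a plain column, nothing to pre-project
    case e =>
      // Semantically equal expressions share one alias, as with the helper's expressionMap.
      val alias = aliased.getOrElseUpdate(e.canonicalized, Alias(e, s"_pre_${aliased.size}")())
      alias.toAttribute // the rewritten UDF now reads the projected column
  }
  (aliased.values.toSeq, rewritten)
}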
@@ -39,7 +39,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
.set("spark.executor.cores", "1")
}

test("arrow_udf test") {
test("arrow_udf test: without projection") {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
@@ -58,4 +58,46 @@
checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df2)
checkAnswer(df2, expected)
}

testWithSpecifiedSparkVersion("arrow_udf test: with unrelated projection", Some("3.3")) {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
lazy val expected = Seq(
("1", 1, "1", 2),
("1", 2, "1", 4),
("2", 1, "2", 2),
("2", 2, "2", 4),
("3", 1, "3", 2),
("3", 2, "3", 4),
("0", 1, "0", 2),
("3", 0, "3", 0)
).toDF("a", "b", "p_a", "d_b")

val df2 = base.withColumns(Map("p_a" -> pyarrowTestUDF(base("a")), "d_b" -> base("b") * 2))
checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df2)
checkAnswer(df2, expected)
}

testWithSpecifiedSparkVersion("arrow_udf test: with preprojection", Some("3.3")) {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
lazy val expected = Seq(
("1", 1, 2, "1", 2),
("1", 2, 4, "1", 4),
("2", 1, 2, "2", 2),
("2", 2, 4, "2", 4),
("3", 1, 2, "3", 2),
("3", 2, 4, "3", 4),
("0", 1, 2, "0", 2),
("3", 0, 0, "3", 0)
).toDF("a", "b", "d_b", "p_a", "p_b")
val df2 = base.withColumns(
Map(
"d_b" -> base("b") * 2,
"p_a" -> pyarrowTestUDF(base("a")),
"p_b" -> pyarrowTestUDF(base("b") * 2)))
checkAnswer(df2, expected)
}
}
@@ -43,6 +43,7 @@ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.types.{LongType, NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -738,5 +739,7 @@

def genPostProjectForGenerate(generate: GenerateExec): SparkPlan

def genPreProjectForArrowEvalPythonExec(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan

def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = plan
}
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression}
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}

import scala.collection.mutable
@@ -226,6 +227,10 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
case generate: GenerateExec =>
BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForGenerate(generate)

case arrowEvalPythonExec: ArrowEvalPythonExec =>
BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForArrowEvalPythonExec(
arrowEvalPythonExec)

case _ => plan
}
}
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.joins.BaseJoinExec
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.window.WindowExec

case class RewrittenNodeWall(originalChild: SparkPlan) extends LeafExecNode {
@@ -60,6 +61,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode]
case _: ExpandExec => true
case _: GenerateExec => true
case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) => true
case _: ArrowEvalPythonExec => true
case _ => false
}
}