diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 465041621a61..01fc9faef2d6 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -860,4 +860,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate
 
   override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate
+
+  override def genPreProjectForArrowEvalPythonExec(
+      arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = arrowEvalPythonExec
 }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 772f1cfb2422..0914c1c918ef 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -31,7 +31,7 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
 import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}
 
 import org.apache.spark.{ShuffleDependency, SparkException}
-import org.apache.spark.api.python.ColumnarArrowEvalPythonExec
+import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.utils.ExecUtil
 import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction}
 import org.apache.spark.sql.internal.SQLConf
@@ -835,6 +836,11 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     PullOutGenerateProjectHelper.pullOutPostProject(generate)
   }
 
+  override def genPreProjectForArrowEvalPythonExec(
+      arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
+    PullOutArrowEvalPythonPreProjectHelper.pullOutPreProject(arrowEvalPythonExec)
+  }
+
   override def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = {
     // This to-top-n optimization assumes exchange operators were already placed in input plan.
     plan.transformUp {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala
index 77ef1c6422b2..dffe668d59b5 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala
@@ -17,17 +17,18 @@ package org.apache.spark.api.python
 
 import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
-import org.apache.gluten.utils.Iterators
+import org.apache.gluten.utils.{Iterators, PullOutProjectHelper}
 import org.apache.gluten.vectorized.ArrowWritableColumnVector
 
 import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.python.{BasePythonRunnerShim, EvalPythonExec, PythonUDFRunner}
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BasePythonRunnerShim, EvalPythonExec, PythonUDFRunner}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.utils.{SparkArrowUtil, SparkSchemaUtil, SparkVectorUtil}
@@ -41,6 +42,7 @@ import java.io.{DataInputStream, DataOutputStream}
 import java.net.Socket
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.{mutable, Seq}
 import scala.collection.mutable.ArrayBuffer
 
 class ColumnarArrowPythonRunner(
@@ -207,7 +209,6 @@ case class ColumnarArrowEvalPythonExec(
   extends EvalPythonExec
   with GlutenPlan {
   override def supportsColumnar: Boolean = true
-  // TODO: add additional projection support by pre-project
   // FIXME: incorrect metrics updater
 
   override protected def evaluate(
@@ -221,6 +222,7 @@ case class ColumnarArrowEvalPythonExec(
   }
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+
   private def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
     val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
     val pandasColsByName = Seq(
@@ -231,6 +233,7 @@ case class ColumnarArrowEvalPythonExec(
       conf.arrowSafeTypeConversion.toString)
     Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
   }
+
   private val pythonRunnerConf = getPythonRunnerConfMap(conf)
 
   protected def evaluateColumnar(
@@ -279,9 +282,11 @@ case class ColumnarArrowEvalPythonExec(
       iter =>
         val context = TaskContext.get()
         val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-        // flatten all the arguments
+        // Only the columns referred to by the UDFs are written to the Python worker,
+        // so we need to record their offsets in the child's output.
         val allInputs = new ArrayBuffer[Expression]
         val dataTypes = new ArrayBuffer[DataType]
+        val originalOffsets = new ArrayBuffer[Int]
         val argOffsets = inputs.map {
           input =>
             input.map {
@@ -289,6 +294,18 @@ case class ColumnarArrowEvalPythonExec(
               if (allInputs.exists(_.semanticEquals(e))) {
                 allInputs.indexWhere(_.semanticEquals(e))
               } else {
+                if (!e.isInstanceOf[AttributeReference]) {
+                  throw new GlutenException(
+                    "ColumnarArrowEvalPythonExec should only have [AttributeReference] inputs.")
+                }
+                var offset: Int = -1
+                offset = child.output.indexWhere(
+                  _.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
+                if (offset == -1) {
+                  throw new GlutenException(
+                    "ColumnarArrowEvalPythonExec cannot find the referred input column.")
+                }
+                originalOffsets += offset
                 allInputs += e
                 dataTypes += e.dataType
                 allInputs.length - 1
@@ -299,15 +316,20 @@ case class ColumnarArrowEvalPythonExec(
             case (dt, i) =>
               StructField(s"_$i", dt)
         }.toSeq)
+
        val contextAwareIterator = new ContextAwareIterator(context, iter)
        val inputCbCache = new ArrayBuffer[ColumnarBatch]()
        val inputBatchIter = contextAwareIterator.map {
          inputCb =>
            ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance, inputCb)
-            // 0. cache input for later merge
            ColumnarBatches.retain(inputCb)
+            // 0. cache input for later merge
            inputCbCache += inputCb
-            inputCb
+            var colsForEval = new ArrayBuffer[ColumnVector]()
+            for (i <- originalOffsets) {
+              colsForEval += inputCb.column(i)
+            }
+            new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
        }
 
        val outputColumnarBatchIterator =
@@ -335,6 +357,65 @@ case class ColumnarArrowEvalPythonExec(
         .create()
     }
   }
+
   override protected def withNewChildInternal(newChild: SparkPlan): ColumnarArrowEvalPythonExec =
     copy(udfs, resultAttrs, newChild)
 }
+
+object PullOutArrowEvalPythonPreProjectHelper extends PullOutProjectHelper {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  private def rewriteUDF(
+      udf: PythonUDF,
+      expressionMap: mutable.HashMap[Expression, NamedExpression]): PythonUDF = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        udf
+          .withNewChildren(udf.children.toIndexedSeq.map {
+            func => rewriteUDF(func.asInstanceOf[PythonUDF], expressionMap)
+          })
+          .asInstanceOf[PythonUDF]
+      case children =>
+        val newUDFChildren = udf.children.map {
+          case literal: Literal => literal
+          case other => replaceExpressionWithAttribute(other, expressionMap)
+        }
+        udf.withNewChildren(newUDFChildren).asInstanceOf[PythonUDF]
+    }
+  }
+
+  def pullOutPreProject(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
+    // Pull out the pre-project.
+    val (_, inputs) = arrowEvalPythonExec.udfs.map(collectFunctions).unzip
+    val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
+    // flatten all the arguments
+    val allInputs = new ArrayBuffer[Expression]
+    for (input <- inputs) {
+      input.map {
+        e =>
+          if (!allInputs.exists(_.semanticEquals(e))) {
+            allInputs += e
+            replaceExpressionWithAttribute(e, expressionMap)
+          }
+      }
+    }
+    if (!expressionMap.isEmpty) {
+      // A pre-project is needed.
+      val preProject = ProjectExec(
+        eliminateProjectList(arrowEvalPythonExec.child.outputSet, expressionMap.values.toSeq),
+        arrowEvalPythonExec.child)
+      val newUDFs = arrowEvalPythonExec.udfs.map(f => rewriteUDF(f, expressionMap))
+      arrowEvalPythonExec.copy(udfs = newUDFs, child = preProject)
+    } else {
+      arrowEvalPythonExec
+    }
+  }
+}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
index 2193448b4d22..0674b06e4ecd 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
@@ -39,7 +39,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
       .set("spark.executor.cores", "1")
   }
 
-  test("arrow_udf test") {
+  test("arrow_udf test: without projection") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
         .toDF("a", "b")
@@ -58,4 +58,46 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
     checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df2)
     checkAnswer(df2, expected)
   }
+
+  testWithSpecifiedSparkVersion("arrow_udf test: with unrelated projection", Some("3.3")) {
+    lazy val base =
+      Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
+        .toDF("a", "b")
+    lazy val expected = Seq(
+      ("1", 1, "1", 2),
+      ("1", 2, "1", 4),
+      ("2", 1, "2", 2),
+      ("2", 2, "2", 4),
+      ("3", 1, "3", 2),
+      ("3", 2, "3", 4),
+      ("0", 1, "0", 2),
+      ("3", 0, "3", 0)
+    ).toDF("a", "b", "p_a", "d_b")
+
+    val df2 = base.withColumns(Map("p_a" -> pyarrowTestUDF(base("a")), "d_b" -> base("b") * 2))
+    checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df2)
+    checkAnswer(df2, expected)
+  }
+
+  testWithSpecifiedSparkVersion("arrow_udf test: with preprojection", Some("3.3")) {
+    lazy val base =
+      Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
+        .toDF("a", "b")
+    lazy val expected = Seq(
+      ("1", 1, 2, "1", 2),
+      ("1", 2, 4, "1", 4),
+      ("2", 1, 2, "2", 2),
+      ("2", 2, 4, "2", 4),
+      ("3", 1, 2, "3", 2),
+      ("3", 2, 4, "3", 4),
+      ("0", 1, 2, "0", 2),
+      ("3", 0, 0, "3", 0)
+    ).toDF("a", "b", "d_b", "p_a", "p_b")
+    val df2 = base.withColumns(
+      Map(
+        "d_b" -> base("b") * 2,
+        "p_a" -> pyarrowTestUDF(base("a")),
+        "p_b" -> pyarrowTestUDF(base("b") * 2)))
+    checkAnswer(df2, expected)
+  }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index fb2fd961b481..9b797bc18561 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 import org.apache.spark.sql.types.{LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -738,5 +739,7 @@ trait SparkPlanExecApi {
 
   def genPostProjectForGenerate(generate: GenerateExec): SparkPlan
 
+  def genPreProjectForArrowEvalPythonExec(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan
+
   def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = plan
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
index 64d4f273622c..50dc55423605 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression}
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}
 
 import scala.collection.mutable
@@ -226,6 +227,10 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
       case generate: GenerateExec =>
         BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForGenerate(generate)
 
+      case arrowEvalPythonExec: ArrowEvalPythonExec =>
+        BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForArrowEvalPythonExec(
+          arrowEvalPythonExec)
+
       case _ => plan
     }
   }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
index 5fd728eca65a..ac663314bead 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window.WindowExec
 
 case class RewrittenNodeWall(originalChild: SparkPlan) extends LeafExecNode {
@@ -60,6 +61,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode]
       case _: ExpandExec => true
       case _: GenerateExec => true
       case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) => true
+      case _: ArrowEvalPythonExec => true
       case _ => false
     }
  }
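
Note (not part of the patch): the standalone sketch below illustrates the idea implemented by PullOutArrowEvalPythonPreProjectHelper in the diff above. Every UDF argument that is not a plain column reference is replaced by a generated alias, and the collected aliases become the projection inserted below the Python eval node; in the real rule the aliases come from replaceExpressionWithAttribute and the projection is built with ProjectExec/eliminateProjectList. The Expr ADT, the Attr/Mul/Alias case classes, the "_pre_N" naming, and the object name are hypothetical illustrations, not Gluten or Spark APIs.

import scala.collection.mutable

// Simplified, hypothetical expression ADT -- not Spark's Expression hierarchy.
object PullOutPreProjectSketch {
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Mul(child: Expr, factor: Int) extends Expr
  final case class Alias(child: Expr, name: String) extends Expr

  def main(args: Array[String]): Unit = {
    // UDF arguments as they would appear under ArrowEvalPythonExec: `a` and `b * 2`.
    val udfArgs: Seq[Expr] = Seq(Attr("a"), Mul(Attr("b"), 2))

    // Replace each non-attribute argument with a generated alias, reusing the alias
    // when the same expression appears more than once (the role played by
    // replaceExpressionWithAttribute in the actual rule).
    val aliases = mutable.LinkedHashMap.empty[Expr, Alias]
    val rewrittenArgs = udfArgs.map {
      case a: Attr => a
      case other =>
        val alias = aliases.getOrElseUpdate(other, Alias(other, s"_pre_${aliases.size}"))
        Attr(alias.name)
    }

    // The aliases are the extra output columns of the pre-project; the rewritten UDF
    // arguments are now plain attribute references, which ColumnarArrowEvalPythonExec expects.
    println(s"pre-project columns: ${aliases.values.mkString(", ")}")
    println(s"rewritten UDF args:  ${rewrittenArgs.mkString(", ")}")
  }
}

Running the sketch prints one generated pre-project column for the `b * 2` argument and shows both UDF arguments rewritten to attribute references, which mirrors the "arrow_udf test: with preprojection" case added to ArrowEvalPythonExecSuite.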