[GLUTEN-5696] Add preprojection support for ArrowEvalPythonExec #5697

Merged: 2 commits, May 16, 2024
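In short, this PR lets the Velox backend's ColumnarArrowEvalPythonExec handle Python UDFs whose arguments are expressions rather than plain columns: a pre-project is pulled out below the eval operator so that it only ever sees AttributeReference inputs, and only the columns actually referred by the UDFs are shipped to the Python worker. A minimal usage sketch, modeled on the new "with preprojection" test added below (pyarrowTestUDF and the column names are taken from that test; the sketch itself is illustrative and not part of the diff):

// Mirrors ArrowEvalPythonExecSuite's new "with preprojection" test.
val base = Seq(("1", 1), ("1", 2), ("2", 1)).toDF("a", "b")
val df = base
  .withColumn("d_b", base("b") * 2)                 // plain projection, no UDF involved
  .withColumn("p_a", pyarrowTestUDF(base("a")))     // UDF over a plain column: no pre-project needed
  .withColumn("p_b", pyarrowTestUDF(base("b") * 2)) // UDF over an expression: b * 2 is pulled out
                                                    // into a ProjectExec before the eval operator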
File: VeloxSparkPlanExecApi.scala
@@ -31,7 +31,7 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}

import org.apache.spark.{ShuffleDependency, SparkException}
import org.apache.spark.api.python.ColumnarArrowEvalPythonExec
import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction}
import org.apache.spark.sql.internal.SQLConf
@@ -846,6 +847,11 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
PullOutGenerateProjectHelper.pullOutPostProject(generate)
}

override def genPreProjectForArrowEvalPythonExec(
arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
PullOutArrowEvalPythonPreProjectHelper.pullOutPreProject(arrowEvalPythonExec)
}

override def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = {
// This to-top-n optimization assumes exchange operators were already placed in input plan.
plan.transformUp {
File: ColumnarArrowEvalPythonExec.scala
@@ -17,17 +17,18 @@
package org.apache.spark.api.python

import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.utils.Iterators
import org.apache.gluten.utils.{Iterators, PullOutProjectHelper}
import org.apache.gluten.vectorized.ArrowWritableColumnVector

import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.python.{BasePythonRunnerShim, EvalPythonExec, PythonUDFRunner}
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BasePythonRunnerShim, EvalPythonExec, PythonUDFRunner}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.utils.{SparkArrowUtil, SparkSchemaUtil, SparkVectorUtil}
@@ -41,6 +42,7 @@ import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.{mutable, Seq}
import scala.collection.mutable.ArrayBuffer

class ColumnarArrowPythonRunner(
@@ -207,7 +209,6 @@ case class ColumnarArrowEvalPythonExec(
extends EvalPythonExec
with GlutenPlan {
override def supportsColumnar: Boolean = true
// TODO: add additional projection support by pre-project
// FIXME: incorrect metrics updater

override protected def evaluate(
@@ -221,6 +222,7 @@
}

private val sessionLocalTimeZone = conf.sessionLocalTimeZone

private def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
val pandasColsByName = Seq(
@@ -231,6 +233,7 @@
conf.arrowSafeTypeConversion.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
}

private val pythonRunnerConf = getPythonRunnerConfMap(conf)

protected def evaluateColumnar(
@@ -279,16 +282,29 @@
iter =>
val context = TaskContext.get()
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
// flatten all the arguments
// We only write the cols referred by UDFs to the python worker, so we need to
// get the corresponding offsets.
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val originalOffsets = new ArrayBuffer[Int]
val argOffsets = inputs.map {
input =>
input.map {
e =>
if (allInputs.exists(_.semanticEquals(e))) {
if (!e.isInstanceOf[AttributeReference]) {
throw new GlutenException(
"ColumnarArrowEvalPythonExec should only has [AttributeReference] inputs.")
} else if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
var offset: Int = -1
offset = child.output.indexWhere(
_.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
if (offset == -1) {
throw new GlutenException(
"ColumnarArrowEvalPythonExec can't find referred input col.")
}
originalOffsets += offset
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
@@ -299,15 +315,21 @@
case (dt, i) =>
StructField(s"_$i", dt)
}.toSeq)

val contextAwareIterator = new ContextAwareIterator(context, iter)
val inputCbCache = new ArrayBuffer[ColumnarBatch]()
val inputBatchIter = contextAwareIterator.map {
inputCb =>
ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance, inputCb)
// 0. cache input for later merge
ColumnarBatches.retain(inputCb)
// 0. cache input for later merge
inputCbCache += inputCb
inputCb
// We only need to pass the data of the referred cols to the python worker for evaluation.
var colsForEval = new ArrayBuffer[ColumnVector]()
for (i <- originalOffsets) {
colsForEval += inputCb.column(i)
}
new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
}

val outputColumnarBatchIterator =
@@ -335,6 +357,65 @@
.create()
}
}

override protected def withNewChildInternal(newChild: SparkPlan): ColumnarArrowEvalPythonExec =
copy(udfs, resultAttrs, newChild)
}

object PullOutArrowEvalPythonPreProjectHelper extends PullOutProjectHelper {
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

private def rewriteUDF(
udf: PythonUDF,
expressionMap: mutable.HashMap[Expression, NamedExpression]): PythonUDF = {
udf.children match {
case Seq(u: PythonUDF) =>
udf
.withNewChildren(udf.children.toIndexedSeq.map {
func => rewriteUDF(func.asInstanceOf[PythonUDF], expressionMap)
})
.asInstanceOf[PythonUDF]
case children =>
val newUDFChildren = udf.children.map {
case literal: Literal => literal
case other => replaceExpressionWithAttribute(other, expressionMap)
}
udf.withNewChildren(newUDFChildren).asInstanceOf[PythonUDF]
}
}

def pullOutPreProject(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
// pull out preproject
val (_, inputs) = arrowEvalPythonExec.udfs.map(collectFunctions).unzip
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
for (input <- inputs) {
input.map {
Review comment (Contributor): input.filter(allInputs.exists(_.semanticEquals(e))
Reply (Author): This line is not meant to filter the inputs but to add the matched expression to allInputs and replace it with a faked Attribute.
e =>
if (!allInputs.exists(_.semanticEquals(e))) {
allInputs += e
replaceExpressionWithAttribute(e, expressionMap)
}
}
}
if (!expressionMap.isEmpty) {
// Need preproject.
val preProject = ProjectExec(
eliminateProjectList(arrowEvalPythonExec.child.outputSet, expressionMap.values.toSeq),
arrowEvalPythonExec.child)
val newUDFs = arrowEvalPythonExec.udfs.map(f => rewriteUDF(f, expressionMap))
arrowEvalPythonExec.copy(udfs = newUDFs, child = preProject)
} else {
arrowEvalPythonExec
}
}
}
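The review exchange above touches the core mechanism of pullOutPreProject: every non-attribute UDF argument is replaced with a generated attribute, and the original expression is remembered in expressionMap so the inserted ProjectExec can compute it (eliminateProjectList combines the child's output with those aliases, and rewriteUDF repoints the UDF children at the new attributes, leaving literals untouched). A minimal, self-contained sketch of that replacement pattern; replaceWithAttribute here is a hypothetical stand-in and does not claim to match the signature of Gluten's actual PullOutProjectHelper.replaceExpressionWithAttribute:

import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}

// Attributes pass through untouched; any other expression is aliased once
// (deduplicated by its canonical form) and the alias's attribute is returned,
// so the UDF ends up referencing a generated column that the pre-project computes.
def replaceWithAttribute(
    e: Expression,
    exprMap: mutable.HashMap[Expression, NamedExpression]): Expression = e match {
  case a: Attribute => a
  case other =>
    exprMap
      .getOrElseUpdate(other.canonicalized, Alias(other, s"_pre_${exprMap.size}")())
      .toAttribute
}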
File: ArrowEvalPythonExecSuite.scala
@@ -39,7 +39,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
.set("spark.executor.cores", "1")
}

test("arrow_udf test") {
test("arrow_udf test: without projection") {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
@@ -58,4 +58,45 @@
checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df2)
checkAnswer(df2, expected)
}

test("arrow_udf test: with unrelated projection") {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
lazy val expected = Seq(
("1", 1, "1", 2),
("1", 2, "1", 4),
("2", 1, "2", 2),
("2", 2, "2", 4),
("3", 1, "3", 2),
("3", 2, "3", 4),
("0", 1, "0", 2),
("3", 0, "3", 0)
).toDF("a", "b", "p_a", "d_b")

val df = base.withColumn("p_a", pyarrowTestUDF(base("a"))).withColumn("d_b", base("b") * 2)
checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df)
checkAnswer(df, expected)
}

test("arrow_udf test: with preprojection") {
lazy val base =
Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
.toDF("a", "b")
lazy val expected = Seq(
("1", 1, 2, "1", 2),
("1", 2, 4, "1", 4),
("2", 1, 2, "2", 2),
("2", 2, 4, "2", 4),
("3", 1, 2, "3", 2),
("3", 2, 4, "3", 4),
("0", 1, 2, "0", 2),
("3", 0, 0, "3", 0)
).toDF("a", "b", "d_b", "p_a", "p_b")
val df = base
.withColumn("d_b", base("b") * 2)
.withColumn("p_a", pyarrowTestUDF(base("a")))
.withColumn("p_b", pyarrowTestUDF(base("b") * 2))
checkAnswer(df, expected)
}
}
File: SparkPlanExecApi.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.types.{LongType, NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -745,6 +746,9 @@ trait SparkPlanExecApi {

def genPostProjectForGenerate(generate: GenerateExec): SparkPlan

def genPreProjectForArrowEvalPythonExec(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan =
arrowEvalPythonExec

def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = plan

def outputNativeColumnarSparkCompatibleData(plan: SparkPlan): Boolean = false
File: PullOutPreProject.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression}
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}

import scala.collection.mutable
@@ -226,6 +227,10 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
case generate: GenerateExec =>
BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForGenerate(generate)

case arrowEvalPythonExec: ArrowEvalPythonExec =>
BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForArrowEvalPythonExec(
arrowEvalPythonExec)

case _ => plan
}
}
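For reference, a rough sketch of the plan rewrite this new case enables, assuming a single UDF argument b * 2 as in the new test; _pre_0 is an illustrative alias name, the real names are generated by PullOutProjectHelper:

// Before the rewrite:
//   ArrowEvalPythonExec(udfs = [udf(b * 2)], child)
// After genPreProjectForArrowEvalPythonExec (roughly):
//   ArrowEvalPythonExec(udfs = [udf(_pre_0)],
//     ProjectExec(child.output :+ Alias(b * 2, "_pre_0"), child))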
File: RewriteSparkPlanRulesManager.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.joins.BaseJoinExec
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.window.WindowExec

case class RewrittenNodeWall(originalChild: SparkPlan) extends LeafExecNode {
@@ -60,6 +61,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode]
case _: ExpandExec => true
case _: GenerateExec => true
case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) => true
case _: ArrowEvalPythonExec => true
case _ => false
}
}