Skip to content

Commit

Permalink
[VL] support spark nanvl function (#4446)
Browse files Browse the repository at this point in the history
[VL] support spark nanvl function.
  • Loading branch information
zhli1142015 committed Jan 19, 2024
1 parent d059206 commit 55e21df
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, Literal, NamedExpression, StringSplit, StringTrim}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, NamedExpression, NaNvl, StringSplit, StringTrim}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
Expand Down Expand Up @@ -112,6 +112,23 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
new IfThenNode(Lists.newArrayList(lessThanFuncNode), Lists.newArrayList(nullNode), resultNode)
}

/**
 * Transform NaNvl to Substrait.
 *
 * The expression is rewritten as a conditional: when `isnan(left)` holds the result is `right`,
 * otherwise it is `left`. The equivalent Spark `If` expression is handed to the resulting
 * transformer as its original expression.
 *
 * @param substraitExprName name resolved for NaNvl; not used by this implementation
 * @param left transformer of the first NaNvl argument
 * @param right transformer of the second NaNvl argument
 * @param original the Spark NaNvl expression being converted
 * @return an IfTransformer wrapping the isnan check and both branches
 */
override def genNaNvlTransformer(
    substraitExprName: String,
    left: ExpressionTransformer,
    right: ExpressionTransformer,
    original: NaNvl): ExpressionTransformer = {
  val isNaNCheck = IsNaN(original.left)
  // The condition is emitted under the name registered for IsNaN, not under NaNvl's own name.
  val isNaNName = ExpressionMappings.expressionsMap(classOf[IsNaN])
  val rewritten = If(isNaNCheck, original.right, original.left)
  val isNaNTransformer = GenericExpressionTransformer(isNaNName, Seq(left), isNaNCheck)
  IfTransformer(isNaNTransformer, right, left, rewritten)
}

/**
* * Plans.
*/
Expand Down Expand Up @@ -518,7 +535,8 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
/**
 * Returns a sequence of Sig entries; each one pairs an expression class (the type argument)
 * with its name constant from ExpressionNames.
 */
override def extraExpressionMappings: Seq[Sig] = {
Seq(
Sig[HLLAdapter](ExpressionNames.APPROX_DISTINCT),
Sig[UDFExpression](ExpressionNames.UDF_PLACEHOLDER))
Sig[UDFExpression](ExpressionNames.UDF_PLACEHOLDER),
Sig[NaNvl](ExpressionNames.NANVL))
}

override def genInjectedFunctions()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,15 @@ class VeloxFunctionsValidateSuite extends VeloxWholeStageTransformerSuite {
checkOperatorMatch[ProjectExecTransformer]
}
}

// Runs nanvl over a NaN cast, casts of the string 'null', a plain column and a division by
// 0.0d, then checks the plan with checkOperatorMatch[ProjectExecTransformer].
test("Test nanvl function") {
  val nanvlQuery =
    """SELECT nanvl(cast('nan' as float), 1f),
| nanvl(l_orderkey, cast('null' as double)),
| nanvl(cast('null' as double), l_orderkey),
| nanvl(l_orderkey, l_orderkey / 0.0d),
| nanvl(cast('nan' as float), l_orderkey)
| from lineitem limit 1""".stripMargin
  runQueryAndCompare(nanvlQuery) {
    checkOperatorMatch[ProjectExecTransformer]
  }
}
}
2 changes: 1 addition & 1 deletion docs/velox-backend-support-progress.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ Gluten supports 199 functions. (Draw to right to see all data types)
| least | least | least | S | | | | | | S | S | S | S | S | | | | | | | | | |
| md5 | md5 | | S | | | S | | | | | | | | | | | | | | | | |
| monotonically_increasing_id | | | | | | | | | | | | | | | | | | | | | | |
| nanvl | | | | | | | | | | | | | | | | | | | | | | |
| nanvl | | | S | | | | | | | | | | | | | | | | | | | |
| nvl | | | | | | | | | | | | | | | | | | | | | | |
| nvl2 | | | | | | | | | | | | | | | | | | | | | | |
| raise_error | | | | | | | | | | | | | | | | | | | | | | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ trait SparkPlanExecApi {
PosExplodeTransformer(substraitExprName, child, original, attributeSeq)
}

/**
 * Transform NaNvl to Substrait.
 *
 * This default implementation performs no conversion and always fails; an implementation that
 * can handle NaNvl is expected to override it.
 *
 * @param substraitExprName name resolved for the NaNvl expression
 * @param left transformer of the first NaNvl argument
 * @param right transformer of the second NaNvl argument
 * @param original the Spark NaNvl expression being converted
 * @throws UnsupportedOperationException always, in this default implementation
 */
def genNaNvlTransformer(
    substraitExprName: String,
    left: ExpressionTransformer,
    right: ExpressionTransformer,
    original: NaNvl): ExpressionTransformer =
  throw new UnsupportedOperationException("NaNvl is not supported")

/**
* Generate ShuffleDependency for ColumnarShuffleExchangeExec.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,13 @@ object ExpressionConverter extends SQLConfHelper with Logging {
rightChild,
resultType,
b)
case nanvlExpr: NaNvl =>
  // NaNvl conversion is delegated to the SparkPlanExecApi instance; both children are
  // converted to transformers first, left before right.
  val execApi = BackendsApiManager.getSparkPlanExecApiInstance
  val leftTransformer =
    replaceWithExpressionTransformerInternal(nanvlExpr.left, attributeSeq, expressionsMap)
  val rightTransformer =
    replaceWithExpressionTransformerInternal(nanvlExpr.right, attributeSeq, expressionsMap)
  execApi.genNaNvlTransformer(substraitExprName, leftTransformer, rightTransformer, nanvlExpr)
case e: Transformable =>
val childrenTransformers =
e.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ object ExpressionNames {
final val IS_NULL = "is_null"
final val NOT = "not"
final val IS_NAN = "isnan"
// Name paired with Spark's NaNvl expression class in the Sig[NaNvl] mapping.
final val NANVL = "nanvl"

// SparkSQL String functions
final val ASCII = "ascii"
Expand Down

0 comments on commit 55e21df

Please sign in to comment.