[VL] Enable array_size Spark function #5539

Status: Closed · 6 commits

CHExpressionUtil.scala
@@ -175,6 +175,7 @@ object CHExpressionUtil {
     ARRAY_EXCEPT -> DefaultValidator(),
     ARRAY_REPEAT -> DefaultValidator(),
     ARRAY_REMOVE -> DefaultValidator(),
+    ARRAY_SIZE -> DefaultValidator(),
     DATE_FROM_UNIX_DATE -> DefaultValidator(),
     UNIX_DATE -> DefaultValidator(),
     MONOTONICALLY_INCREASING_ID -> DefaultValidator(),
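
Note: although this PR targets the Velox backend ([VL]), the ClickHouse backend's expression table also gets an entry. Pairing ARRAY_SIZE with DefaultValidator() appears to mark the function as not offloadable to ClickHouse, so array_size keeps falling back to vanilla Spark there. A minimal, self-contained sketch of that lookup pattern (illustrative only, not Gluten's actual API; every name besides "array_size" is invented here):

    // Hypothetical sketch of a name-keyed validator table: an entry with a
    // default validator blocks offload of that function on this backend.
    sealed trait Validator
    case class DefaultValidator() extends Validator

    val chBlocked: Map[String, Validator] = Map(
      "array_size" -> DefaultValidator()
    )

    // Offload a scalar function to ClickHouse only when no blocking entry exists.
    def canOffloadToCH(functionName: String): Boolean =
      !chBlocked.contains(functionName)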

ScalarFunctionsValidateSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.gluten.execution

+import org.apache.gluten.sql.shims.SparkShimLoader
+
 import org.apache.spark.sql.types._

 import java.sql.Timestamp
@@ -826,4 +828,21 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
     }
   }

+  test("test array_size") {
+    if (!SparkShimLoader.getSparkVersion.startsWith("3.2")) {
+      withTempPath {
+        path =>
+          Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
+            .toDF("value")
+            .write
+            .parquet(path.getCanonicalPath)
+
+          spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")
+
+          runQueryAndCompare("select array_size(value) as res from array_tbl;") {
+            checkGlutenOperatorMatch[ProjectExecTransformer]
+          }
+      }
+    }
+  }
 }
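
Context for the version guard: array_size was added in Spark 3.3, so there is no Catalyst ArraySize expression to offload on Spark 3.2. The semantics the comparison relies on: null elements count toward the size, and a null array yields null, which runQueryAndCompare checks against the vanilla Spark baseline. A spark-shell sketch of the expected results for data like the test's (assumes a session named spark; one row dropped for brevity):

    import spark.implicits._

    Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq.empty, null)
      .toDF("value")
      .selectExpr("array_size(value) as res")
      .show()
    // +----+
    // | res|
    // +----+
    // |   4|   <- the null element is counted
    // |   0|
    // |null|   <- null input yields null
    // +----+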

ExpressionNames.scala
@@ -246,6 +246,7 @@ object ExpressionNames {
   final val ARRAY_EXCEPT = "array_except"
   final val ARRAY_REPEAT = "array_repeat"
   final val ARRAY_REMOVE = "array_remove"
+  final val ARRAY_SIZE = "array_size"
   final val FILTER = "filter"
   final val FORALL = "forall"
   final val EXISTS = "exists"

Spark33Shims.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.sql.shims.spark33

 import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
 import org.apache.gluten.expression.{ExpressionNames, Sig}
-import org.apache.gluten.expression.ExpressionNames.{CEIL, FLOOR, KNOWN_NULLABLE, TIMESTAMP_ADD}
+import org.apache.gluten.expression.ExpressionNames.{ARRAY_SIZE, CEIL, FLOOR, KNOWN_NULLABLE, TIMESTAMP_ADD}
 import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}

 import org.apache.spark._
@@ -69,7 +69,8 @@ class Spark33Shims extends SparkShims {
       Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
       Sig[TimestampAdd](TIMESTAMP_ADD),
       Sig[RoundFloor](FLOOR),
-      Sig[RoundCeil](CEIL)
+      Sig[RoundCeil](CEIL),
+      Sig[ArraySize](ARRAY_SIZE)
     )
   }
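
Sig[ArraySize](ARRAY_SIZE) ties Spark's ArraySize expression class to the "array_size" name the native side resolves, and the mapping is registered per shim because only Spark 3.3+ ships the expression. A rough, self-contained sketch of the registration pattern, modeled on the Sig usage visible in this diff rather than Gluten's actual definition:

    import scala.reflect.ClassTag

    // Stand-in for Spark's Catalyst ArraySize expression (3.3+ only).
    final class ArraySizeStandIn

    // Pairs an expression class with the function name used on the native side.
    final case class Sig(expressionClass: Class[_], functionName: String)
    object Sig {
      def apply[T](name: String)(implicit tag: ClassTag[T]): Sig =
        Sig(tag.runtimeClass, name)
    }

    val mapping = Sig[ArraySizeStandIn]("array_size")
    // mapping.expressionClass -> class ArraySizeStandIn
    // mapping.functionName    -> "array_size"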

Spark34Shims.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.sql.shims.spark34

 import org.apache.gluten.expression.{ExpressionNames, Sig}
-import org.apache.gluten.expression.ExpressionNames.KNOWN_NULLABLE
+import org.apache.gluten.expression.ExpressionNames.{ARRAY_SIZE, KNOWN_NULLABLE}
 import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}

 import org.apache.spark._
@@ -70,7 +70,8 @@ class Spark34Shims extends SparkShims {
       Sig[Sec](ExpressionNames.SEC),
       Sig[Csc](ExpressionNames.CSC),
       Sig[KnownNullable](KNOWN_NULLABLE),
-      Sig[Empty2Null](ExpressionNames.EMPTY2NULL)
+      Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
+      Sig[ArraySize](ARRAY_SIZE)
     )
   }

Spark35Shims.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.sql.shims.spark35

 import org.apache.gluten.expression.{ExpressionNames, Sig}
+import org.apache.gluten.expression.ExpressionNames.ARRAY_SIZE
 import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}

 import org.apache.spark._
@@ -70,7 +71,9 @@ class Spark35Shims extends SparkShims {
       Sig[SplitPart](ExpressionNames.SPLIT_PART),
       Sig[Sec](ExpressionNames.SEC),
       Sig[Csc](ExpressionNames.CSC),
-      Sig[Empty2Null](ExpressionNames.EMPTY2NULL))
+      Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
+      Sig[ArraySize](ARRAY_SIZE)
+    )
   }

   override def aggregateExpressionMappings: Seq[Sig] = {
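
With the 3.3/3.4/3.5 mappings in place, a quick manual smoke check (hypothetical spark-shell session with Gluten's Velox backend enabled) is to confirm the transformer shows up in the physical plan, mirroring what checkGlutenOperatorMatch asserts in the new test:

    // The plan should contain ProjectExecTransformer rather than a vanilla
    // Project when the offload succeeds.
    spark.sql("select array_size(array(1, 2, 3)) as res").explain()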