Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-43939][CONNECT][PYTHON] Add try_* functions to Scala and Python #41653

Closed
wants to merge 8 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -1807,6 +1807,75 @@ object functions {
*/
def sqrt(colName: String): Column = sqrt(Column(colName))

/**
* Returns the sum of `left` and `right` and the result is null on overflow. The acceptable
* input types are the same with the `+` operator.
*
* @note
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be supported naturally in Connect

* Only Numeric type is supported in this function, while `try_add` in SQL supports Numeric,
* DATE, TIMESTAMP, and INTERVAL.
*
* @group math_funcs
* @since 3.5.0
*/
def try_add(left: Column, right: Column): Column = Column.fn("try_add", left, right)

/**
* Returns the mean calculated from values of a group and the result is null on overflow.
*
* @group math_funcs
* @since 3.5.0
*/
def try_avg(e: Column): Column = Column.fn("try_avg", e)

/**
* Returns `dividend``/``divisor`. It always performs floating point division. Its result is
* always null if `divisor` is 0.
*
* @note
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this PR, let's use call_udf for better parity

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👌🏻

* The `dividend` must be a numeric, `divisor` must be a numeric in this function. While the
* `dividend` can be a numeric or an interval, `divisor` must be a numeric in SQL function
* `try_divide`.
*
* @group math_funcs
* @since 3.5.0
*/
def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right)

/**
* Returns `left``*``right` and the result is null on overflow. The acceptable input types are the
* same with the `*` operator.
*
* @note
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

* Only Numeric type is supported in this function, while `try_multiply` in SQL supports
* Numeric and INTERVAL.
*
* @group math_funcs
* @since 3.5.0
*/
def try_multiply(left: Column, right: Column): Column = Column.fn("try_multiply", left, right)

/**
* Returns `left`-`right` and the result is null on overflow. The acceptable input types are the
* same with the `-` operator.
*
* @note
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

* Only Numeric type is supported in this function, while `try_subtract` in SQL supports
* Numeric, DATE, TIMESTAMP, and INTERVAL.
*
* @group math_funcs
* @since 3.5.0
*/
def try_subtract(left: Column, right: Column): Column = Column.fn("try_subtract", left, right)

/**
* Returns the sum calculated from values of a group and the result is null on overflow.
*
* @group math_funcs
* @since 3.5.0
*/
def try_sum(e: Column): Column = Column.fn("try_sum", e)

/**
* Creates a new struct column. If the input column is a column in a `DataFrame`, or a derived
* column expression that is named (i.e. aliased), its name would be retained as the
Expand Down Expand Up @@ -3971,6 +4040,34 @@ object functions {
def startswith(str: Column, prefix: Column): Column =
Column.fn("startswith", str, prefix)

/**
* This is a special version of `to_binary` that performs the same operation, but returns a NULL
* value instead of raising an error if the conversion cannot be performed.
*
* @group string_funcs
* @since 3.5.0
*/
def try_to_binary(e: Column, f: Column): Column = Column.fn("try_to_binary", e, f)

/**
* This is a special version of `to_binary` that performs the same operation, but returns a NULL
* value instead of raising an error if the conversion cannot be performed.
*
* @group string_funcs
* @since 3.5.0
*/
def try_to_binary(e: Column): Column = Column.fn("try_to_binary", e)

/**
* Convert string 'e' to a number based on the string format `format`. Returns NULL if the
* string 'e' does not match the expected format. The format follows the same semantics as the
* to_number function.
*
* @group string_funcs
* @since 3.5.0
*/
def try_to_number(e: Column, format: Column): Column = Column.fn("try_to_number", e, format)

//////////////////////////////////////////////////////////////////////////////////////////////
// DateTime functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -4474,6 +4571,27 @@ object functions {
*/
def to_timestamp(s: Column, fmt: String): Column = Column.fn("to_timestamp", s, lit(fmt))

/**
* Parses the `s` with the `format` to a timestamp. The function always returns null on an
* invalid input with/without ANSI SQL mode enabled. The result data type is consistent with the
* value of configuration `spark.sql.timestampType`.
*
* @group datetime_funcs
* @since 3.5.0
*/
def try_to_timestamp(s: Column, format: Column): Column =
Column.fn("try_to_timestamp", s, format)

/**
* Parses the `s` expression to a timestamp. The function always returns null on an invalid
* input with/without ANSI SQL mode enabled. It follows casting rules to a timestamp. The result
* data type is consistent with the value of configuration `spark.sql.timestampType`.
*
* @group datetime_funcs
* @since 3.5.0
*/
def try_to_timestamp(s: Column): Column = Column.fn("try_to_timestamp", s)

/**
* Converts the column into `DateType` by casting rules to `DateType`.
*
Expand Down Expand Up @@ -5034,6 +5152,20 @@ object functions {
*/
def element_at(column: Column, value: Any): Column = Column.fn("element_at", column, lit(value))

/**
* (array, index) - Returns element of array at given (1-based) index. If Index is 0, Spark will
* throw an error. If index < 0, accesses elements from the last to the first. The function
* always returns NULL if the index exceeds the length of the array.
*
* (map, key) - Returns value for given key. The function always returns NULL if the key is not
* contained in the map.
*
* @group map_funcs
* @since 3.5.0
*/
def try_element_at(column: Column, value: Column): Column =
Column.fn("try_element_at", column, value)

/**
* Returns element of array at given (0-based) index. If the index points outside of the array
* boundaries, then this function returns NULL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,58 @@ class PlanGenerationTestSuite
fn.sqrt("b")
}

functionTest("try_add") {
fn.try_add(fn.col("a"), fn.col("a"))
}

functionTest("try_avg") {
fn.try_avg(fn.col("a"))
}

functionTest("try_divide") {
fn.try_divide(fn.col("a"), fn.col("a"))
}

functionTest("try_multiply") {
fn.try_multiply(fn.col("a"), fn.col("a"))
}

functionTest("try_subtract") {
fn.try_subtract(fn.col("a"), fn.col("a"))
}

functionTest("try_sum") {
fn.try_sum(fn.col("a"))
}

functionTest("try_to_timestamp") {
fn.try_to_timestamp(fn.col("g"), fn.col("g"))
}

functionTest("try_to_timestamp without format") {
fn.try_to_timestamp(fn.col("g"))
}

functionTest("try_to_binary") {
fn.try_to_binary(fn.col("g"), lit("format"))
}

functionTest("try_to_binary without format") {
fn.try_to_binary(fn.col("g"))
}

functionTest("try_to_number") {
fn.try_to_number(fn.col("g"), lit("99,999"))
}

functionTest("try_element_at array") {
fn.try_element_at(fn.col("e"), fn.col("a"))
}

functionTest("try_element_at map") {
fn.try_element_at(fn.col("f"), fn.col("g"))
}

functionTest("struct") {
fn.struct("a", "d")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [(a#0 + a#0) AS try_add(a, a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Aggregate [try_avg(a#0) AS try_avg(a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [(cast(a#0 as double) / cast(a#0 as double)) AS try_divide(a, a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [element_at(e#0, a#0, None, false) AS try_element_at(e, a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [element_at(f#0, g#0, None, false) AS try_element_at(f, g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [(a#0 * a#0) AS try_multiply(a, a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [(a#0 - a#0) AS try_subtract(a, a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Aggregate [try_sum(a#0) AS try_sum(a)#0L]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [tryeval(null) AS try_to_binary(g, format)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [tryeval(unhex(g#0, true)) AS try_to_binary(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [try_to_number(g#0, 99,999) AS try_to_number(g, 99,999)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [gettimestamp(g#0, g#0, TimestampType, Some(America/Los_Angeles), false) AS try_to_timestamp(g, g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [cast(g#0 as timestamp) AS try_to_timestamp(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "try_add",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}, {
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}]
}
}]
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "try_avg",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}]
}
}]
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "try_divide",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}, {
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}]
}
}]
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "try_element_at",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "e"
}
}, {
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}]
}
}]
}
}
Binary file not shown.