From 0fc04a642107048aa7411a75e2e7f4ee39ef4922 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 6 Jun 2023 09:10:42 +0800 Subject: [PATCH 1/3] [SPARK-43935][SQL][PYTHON][CONNECT] Add xpath_* functions to Scala and Python --- .../org/apache/spark/sql/functions.scala | 87 ++++++++++++ .../spark/sql/PlanGenerationTestSuite.scala | 36 +++++ .../explain-results/function_xpath.explain | 2 + .../function_xpath_boolean.explain | 2 + .../function_xpath_double.explain | 2 + .../function_xpath_float.explain | 2 + .../function_xpath_int.explain | 2 + .../function_xpath_long.explain | 2 + .../function_xpath_number.explain | 2 + .../function_xpath_short.explain | 2 + .../function_xpath_string.explain | 2 + .../query-tests/queries/function_xpath.json | 29 ++++ .../queries/function_xpath.proto.bin | Bin 0 -> 136 bytes .../queries/function_xpath_boolean.json | 29 ++++ .../queries/function_xpath_boolean.proto.bin | Bin 0 -> 137 bytes .../queries/function_xpath_double.json | 29 ++++ .../queries/function_xpath_double.proto.bin | Bin 0 -> 136 bytes .../queries/function_xpath_float.json | 29 ++++ .../queries/function_xpath_float.proto.bin | Bin 0 -> 135 bytes .../queries/function_xpath_int.json | 29 ++++ .../queries/function_xpath_int.proto.bin | Bin 0 -> 132 bytes .../queries/function_xpath_long.json | 29 ++++ .../queries/function_xpath_long.proto.bin | Bin 0 -> 133 bytes .../queries/function_xpath_number.json | 29 ++++ .../queries/function_xpath_number.proto.bin | Bin 0 -> 136 bytes .../queries/function_xpath_short.json | 29 ++++ .../queries/function_xpath_short.proto.bin | Bin 0 -> 135 bytes .../queries/function_xpath_string.json | 29 ++++ .../queries/function_xpath_string.proto.bin | Bin 0 -> 136 bytes .../reference/pyspark.sql/functions.rst | 15 +++ python/pyspark/sql/connect/functions.py | 63 +++++++++ python/pyspark/sql/functions.py | 126 ++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 97 ++++++++++++++ .../spark/sql/XPathFunctionsSuite.scala | 17 +++ 34 files changed, 720 insertions(+) create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_boolean.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_boolean.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_double.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_double.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_float.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_float.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_int.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_int.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_long.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_long.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_number.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_number.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_short.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_short.proto.bin create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_string.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_xpath_string.proto.bin diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index d82edf4a69a62..c7a1f501a8749 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -4248,6 +4248,93 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) + /** + * Returns a string array of values within the nodes of xml that match the XPath expression. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath(xml: Column, path: String): Column = + Column.fn("xpath", xml, lit(path)) + + /** + * Returns true if the XPath expression evaluates to true, or if a matching node is found. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_boolean(xml: Column, path: String): Column = + Column.fn("xpath_boolean", xml, lit(path)) + + /** + * Returns a double value, the value zero if no match is found, or NaN if a match is found but + * the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_double(xml: Column, path: String): Column = + Column.fn("xpath_double", xml, lit(path)) + + /** + * Returns a double value, the value zero if no match is found, or NaN if a match is found but + * the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_number(xml: Column, path: String): Column = + Column.fn("xpath_number", xml, lit(path)) + + /** + * Returns a float value, the value zero if no match is found, or NaN if a match is found but + * the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_float(xml: Column, path: String): Column = + Column.fn("xpath_float", xml, lit(path)) + + /** + * Returns an integer value, or the value zero if no match is found, or a match is found but the + * value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_int(xml: Column, path: String): Column = + Column.fn("xpath_int", xml, lit(path)) + + /** + * Returns a long integer value, or the value zero if no match is found, or a match is found but + * the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_long(xml: Column, path: String): Column = + Column.fn("xpath_long", xml, lit(path)) + + /** + * Returns a short integer value, or the value zero if no match is found, or a match is found + * but the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_short(xml: Column, path: String): Column = + Column.fn("xpath_short", xml, lit(path)) + + /** + * Returns the text contents of the first xml node that matches the XPath expression. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_string(xml: Column, path: String): Column = + Column.fn("xpath_string", xml, lit(path)) + private def newLambdaVariable(name: String): proto.Expression.UnresolvedNamedLambdaVariable = { proto.Expression.UnresolvedNamedLambdaVariable .newBuilder() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 607db2ee08689..8df0907c4ca2f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -1655,6 +1655,42 @@ class PlanGenerationTestSuite fn.to_date(fn.col("s"), "yyyy-MM-dd") } + temporalFunctionTest("xpath") { + fn.xpath(fn.col("s"), "a/b/text()") + } + + temporalFunctionTest("xpath_boolean") { + fn.xpath_boolean(fn.col("s"), "a/b") + } + + temporalFunctionTest("xpath_double") { + fn.xpath_double(fn.col("s"), "a/b") + } + + temporalFunctionTest("xpath_number") { + fn.xpath_number(fn.col("s"), "a/b") + } + + temporalFunctionTest("xpath_float") { + fn.xpath_float(fn.col("s"), "a/b") + } + + temporalFunctionTest("xpath_int") { + fn.xpath_int(fn.col("s"), "a/b") + } + + temporalFunctionTest("xpath_long") { + fn.xpath_long(fn.col("s"), "a/b") + } + + temporalFunctionTest("xpath_short") { + fn.xpath_short(fn.col("s"), "a/b") + } + + temporalFunctionTest("xpath_string") { + fn.xpath_string(fn.col("s"), "a/b") + } + temporalFunctionTest("trunc") { fn.trunc(fn.col("d"), "mm") } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain new file mode 100644 index 0000000000000..d9e2e55d9b12e --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain @@ -0,0 +1,2 @@ +Project [xpath(s#0, a/b/text()) AS xpath(s, a/b/text())#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain new file mode 100644 index 0000000000000..9b75f81802467 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain @@ -0,0 +1,2 @@ +Project [xpath_boolean(s#0, a/b) AS xpath_boolean(s, a/b)#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain new file mode 100644 index 0000000000000..9ce47136df242 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain @@ -0,0 +1,2 @@ +Project [xpath_double(s#0, a/b) AS xpath_double(s, a/b)#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain new file mode 100644 index 0000000000000..02b29ec4afa9c --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain @@ -0,0 +1,2 @@ +Project [xpath_float(s#0, a/b) AS xpath_float(s, a/b)#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain new file mode 100644 index 0000000000000..cdd56eaa73199 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain @@ -0,0 +1,2 @@ +Project [xpath_int(s#0, a/b) AS xpath_int(s, a/b)#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain new file mode 100644 index 0000000000000..3acefb13d0f8c --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain @@ -0,0 +1,2 @@ +Project [xpath_long(s#0, a/b) AS xpath_long(s, a/b)#0L] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain new file mode 100644 index 0000000000000..0a30685f0c6d2 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain @@ -0,0 +1,2 @@ +Project [xpath_number(s#0, a/b) AS xpath_number(s, a/b)#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain new file mode 100644 index 0000000000000..ed440972bf490 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain @@ -0,0 +1,2 @@ +Project [xpath_short(s#0, a/b) AS xpath_short(s, a/b)#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain new file mode 100644 index 0000000000000..f4103f68c3bc3 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain @@ -0,0 +1,2 @@ +Project [xpath_string(s#0, a/b) AS xpath_string(s, a/b)#0] ++- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath.json b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath.json new file mode 100644 index 0000000000000..3dea90a13653d --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "xpath", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "s" + } + }, { + "literal": { + "string": "a/b/text()" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..aabfc76f8a7e1e76d4eb5d79faee5f7442191afb GIT binary patch literal 136 zcmd;L5@3{SWaNtFViI783Kt43E-5NaF0o0mN=Yn9)hV$m$;?eHE=kNS&?&Y8%4O!I z>r_}JWu|B5mFSe0fP_HGfhvlS)u!g9AavTXm!wvdXlMceefuhR literal 0 HcmV?d00001 diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_boolean.json b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_boolean.json new file mode 100644 index 0000000000000..793d459ec165b --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_boolean.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "xpath_boolean", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "s" + } + }, { + "literal": { + "string": "a/b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_boolean.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_boolean.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..544caab4ecc5b012ab8e9465e2b68b23d3fa4d21 GIT binary patch literal 137 zcmd;L5@3{SV&sbDViI783Kt43E-5NaF0o0mN=Yn9)hV$m$;?eHE=kNS&?&Y8%4O!I z>r_}JWu|B5mFSe0fP_HGfhvlS)u!g9AavTr_}JWu|B5mFSe0fP_HGfhvlS)u!g9AavTL&pJouVrK literal 0 HcmV?d00001 diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_float.json b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_float.json new file mode 100644 index 0000000000000..94932891225d7 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_float.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "xpath_float", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "s" + } + }, { + "literal": { + "string": "a/b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_float.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_float.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..32dfbc00cfa44ff9e462fc4898d6f30b5eca3f36 GIT binary patch literal 135 zcmd;L5@3{SVC0JBViI783Kt43E-5NaF0o0mN=Yn9)hV$m$;?eHE=kNS&?&Y8%4O!I z>r_}JWu|B5mFSe0fP_HGfhvlS)u!g9AavT^5% O3$b&tW-%w~CjkI$n<|w6 literal 0 HcmV?d00001 diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_int.json b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_int.json new file mode 100644 index 0000000000000..0dcef00ed20d4 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_int.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "xpath_int", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "s" + } + }, { + "literal": { + "string": "a/b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_int.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_int.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..e6298b37dbe36508e77927aed8b061102cdb5d17 GIT binary patch literal 132 zcmd;L5@3|7^5%3$b&t KW-%w~CjkHr%qhVD literal 0 HcmV?d00001 diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_long.json b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_long.json new file mode 100644 index 0000000000000..c740d2bad4f5f --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_long.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "xpath_long", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "s" + } + }, { + "literal": { + "string": "a/b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_long.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_long.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..d240600eabbae5b72b07796cb51722330aae37e6 GIT binary patch literal 133 zcmd;L5@3|7=ZfWG5@3i57YZ#dDJo4au}QH?Ni0d#DX}Wa%uOvWNz5(KDYgR2W#*;p zR9Gctrf24r=#-a$gh0xHDvFTRrsky}blTZT$w|p@aa9x~mSn`|r_}JWu|B5mFSe0fP_HGfhvlS)u!g9AavTL&pJp1Ugy literal 0 HcmV?d00001 diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_short.json b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_short.json new file mode 100644 index 0000000000000..5d3a3e9983707 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_short.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "xpath_short", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "s" + } + }, { + "literal": { + "string": "a/b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_short.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_short.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..9ae27bd973853ebfa33e67346c342d9e75d84e64 GIT binary patch literal 135 zcmd;L5@3{SVC0JBViI783Kt43E-5NaF0o0mN=Yn9)hV$m$;?eHE=kNS&?&Y8%4O!I z>r_}JWu|B5mFSe0fP_HGfhvlS)u!g9AavT^5% O3$b&tW-%w~CjkI&jw-JJ literal 0 HcmV?d00001 diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_string.json b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_string.json new file mode 100644 index 0000000000000..26e4130ae2c4b --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_string.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cd:date,t:timestamp,s:string,x:bigint,wt:struct\u003cstart:timestamp,end:timestamp\u003e\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "xpath_string", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "s" + } + }, { + "literal": { + "string": "a/b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_string.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_xpath_string.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..5384301238b1ea9eafa625a619cced9f2ef70e06 GIT binary patch literal 136 zcmd;L5@3{SWaNtFViI783Kt43E-5NaF0o0mN=Yn9)hV$m$;?eHE=kNS&?&Y8%4O!I z>r_}JWu|B5mFSe0fP_HGfhvlS)u!g9AavT Column: to_timestamp.__doc__ = pysparkfuncs.to_timestamp.__doc__ +def xpath(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath", _to_col(xml), lit(path)) + + +xpath.__doc__ = pysparkfuncs.xpath.__doc__ + + +def xpath_boolean(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_boolean", _to_col(xml), lit(path)) + + +xpath_boolean.__doc__ = pysparkfuncs.xpath_boolean.__doc__ + + +def xpath_double(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_double", _to_col(xml), lit(path)) + + +xpath_double.__doc__ = pysparkfuncs.xpath_double.__doc__ + + +def xpath_number(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_number", _to_col(xml), lit(path)) + + +xpath_number.__doc__ = pysparkfuncs.xpath_number.__doc__ + + +def xpath_float(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_float", _to_col(xml), lit(path)) + + +xpath_float.__doc__ = pysparkfuncs.xpath_float.__doc__ + + +def xpath_int(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_int", _to_col(xml), lit(path)) + + +xpath_int.__doc__ = pysparkfuncs.xpath_int.__doc__ + + +def xpath_long(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_long", _to_col(xml), lit(path)) + + +xpath_long.__doc__ = pysparkfuncs.xpath_long.__doc__ + + +def xpath_short(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_short", _to_col(xml), lit(path)) + + +xpath_short.__doc__ = pysparkfuncs.xpath_short.__doc__ + + +def xpath_string(xml: "ColumnOrName", path: str) -> Column: + return _invoke_function("xpath_string", _to_col(xml), lit(path)) + + +xpath_string.__doc__ = pysparkfuncs.xpath_string.__doc__ + + def trunc(date: "ColumnOrName", format: str) -> Column: return _invoke_function("trunc", _to_col(date), lit(format)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fe35f12c40215..61aa06c14606e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -5033,6 +5033,132 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: return _invoke_function("to_timestamp", _to_java_column(col), format) +@try_remote_functions +def xpath(xml: "ColumnOrName", path: str) -> Column: + """ + Returns a string array of values within the nodes of xml that match the XPath expression. + + Examples + -------- + >>> df = spark.createDataFrame([('b1b2b3c1c2',)], ['x']) + >>> df.select(xpath(df.x, 'a/b/text()').alias('r')).collect() + [Row(r=['b1', 'b2', 'b3'])] + """ + return _invoke_function("xpath", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_boolean(xml: "ColumnOrName", path: str) -> Column: + """ + Returns true if the XPath expression evaluates to true, or if a matching node is found. + + Examples + -------- + >>> df = spark.createDataFrame([('1',)], ['x']) + >>> df.select(xpath_boolean(df.x, 'a/b').alias('r')).collect() + [Row(r=True)] + """ + return _invoke_function("xpath_boolean", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_double(xml: "ColumnOrName", path: str) -> Column: + """ + Returns a double value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. + + Examples + -------- + >>> df = spark.createDataFrame([('12',)], ['x']) + >>> df.select(xpath_double(df.x, 'sum(a/b)').alias('r')).collect() + [Row(r=3.0)] + """ + return _invoke_function("xpath_double", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_number(xml: "ColumnOrName", path: str) -> Column: + """ + Returns a double value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. + + Examples + -------- + >>> df = spark.createDataFrame([('12',)], ['x']) + >>> df.select(xpath_number(df.x, 'sum(a/b)').alias('r')).collect() + [Row(r=3.0)] + """ + return _invoke_function("xpath_number", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_float(xml: "ColumnOrName", path: str) -> Column: + """ + Returns a float value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. + + Examples + -------- + >>> df = spark.createDataFrame([('12',)], ['x']) + >>> df.select(xpath_float(df.x, 'sum(a/b)').alias('r')).collect() + [Row(r=3.0)] + """ + return _invoke_function("xpath_float", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_int(xml: "ColumnOrName", path: str) -> Column: + """ + Returns an integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. + + Examples + -------- + >>> df = spark.createDataFrame([('12',)], ['x']) + >>> df.select(xpath_int(df.x, 'sum(a/b)').alias('r')).collect() + [Row(r=3)] + """ + return _invoke_function("xpath_int", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_long(xml: "ColumnOrName", path: str) -> Column: + """ + Returns a long integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. + + Examples + -------- + >>> df = spark.createDataFrame([('12',)], ['x']) + >>> df.select(xpath_long(df.x, 'sum(a/b)').alias('r')).collect() + [Row(r=3)] + """ + return _invoke_function("xpath_long", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_short(xml: "ColumnOrName", path: str) -> Column: + """ + Returns a short integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. + + Examples + -------- + >>> df = spark.createDataFrame([('12',)], ['x']) + >>> df.select(xpath_short(df.x, 'sum(a/b)').alias('r')).collect() + [Row(r=3)] + """ + return _invoke_function("xpath_short", _to_java_column(xml), path) + + +@try_remote_functions +def xpath_string(xml: "ColumnOrName", path: str) -> Column: + """ + Returns the text contents of the first xml node that matches the XPath expression. + + Examples + -------- + >>> df = spark.createDataFrame([('bcc',)], ['x']) + >>> df.select(xpath_string(df.x, 'a/c').alias('r')).collect() + [Row(r='cc')] + """ + return _invoke_function("xpath_string", _to_java_column(xml), path) + + @try_remote_functions def trunc(date: "ColumnOrName", format: str) -> Column: """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 130614d342f2a..eb5722eadf0ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TimestampFormatter} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -5243,6 +5244,102 @@ object functions { def days(e: Column): Column = withExpr { Days(e.expr) } /** + * Returns a string array of values within the nodes of xml that match the XPath expression. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath(x: Column, p: String): Column = withExpr { + XPathList(x.expr, lit(p).expr) + } + + /** + * Returns true if the XPath expression evaluates to true, or if a matching node is found. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_boolean(x: Column, p: String): Column = withExpr { + XPathBoolean(x.expr, lit(p).expr) + } + + /** + * Returns a double value, the value zero if no match is found, + * or NaN if a match is found but the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_double(x: Column, p: String): Column = withExpr { + XPathDouble(x.expr, lit(p).expr) + } + + /** + * Returns a double value, the value zero if no match is found, + * or NaN if a match is found but the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_number(x: Column, p: String): Column = withExpr { + XPathDouble(x.expr, lit(p).expr) + } + + /** + * Returns a float value, the value zero if no match is found, + * or NaN if a match is found but the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_float(x: Column, p: String): Column = withExpr { + XPathFloat(x.expr, lit(p).expr) + } + + /** + * Returns an integer value, or the value zero if no match is found, + * or a match is found but the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_int(x: Column, p: String): Column = withExpr { + XPathInt(x.expr, lit(p).expr) + } + + /** + * Returns a long integer value, or the value zero if no match is found, + * or a match is found but the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_long(x: Column, p: String): Column = withExpr { + XPathLong(x.expr, lit(p).expr) + } + + /** + * Returns a short integer value, or the value zero if no match is found, + * or a match is found but the value is non-numeric. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_short(x: Column, p: String): Column = withExpr { + XPathShort(x.expr, lit(p).expr) + } + + /** + * Returns the text contents of the first xml node that matches the XPath expression. + * + * @group "xml_funcs" + * @since 3.5.0 + */ + def xpath_string(x: Column, p: String): Column = withExpr { + XPathString(x.expr, lit(p).expr) + } + + /** * A transform for timestamps to partition data into hours. * * @group partition_transforms diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala index a25cca7af50bd..cc2d3ba5ef4c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession /** @@ -28,6 +29,7 @@ class XPathFunctionsSuite extends QueryTest with SharedSparkSession { test("xpath_boolean") { val df = Seq("b").toDF("xml") checkAnswer(df.selectExpr("xpath_boolean(xml, 'a/b')"), Row(true)) + checkAnswer(df.select(xpath_boolean(col("xml"), "a/b")), Row(true)) } test("xpath_short, xpath_int, xpath_long") { @@ -38,6 +40,12 @@ class XPathFunctionsSuite extends QueryTest with SharedSparkSession { "xpath_int(xml, 'sum(a/b)')", "xpath_long(xml, 'sum(a/b)')"), Row(3.toShort, 3, 3L)) + checkAnswer( + df.select( + xpath_short(col("xml"), "sum(a/b)"), + xpath_int(col("xml"), "sum(a/b)"), + xpath_long(col("xml"), "sum(a/b)")), + Row(3.toShort, 3, 3L)) } test("xpath_float, xpath_double, xpath_number") { @@ -48,15 +56,24 @@ class XPathFunctionsSuite extends QueryTest with SharedSparkSession { "xpath_double(xml, 'sum(a/b)')", "xpath_number(xml, 'sum(a/b)')"), Row(3.1.toFloat, 3.1, 3.1)) + checkAnswer( + df.select( + xpath_float(col("xml"), "sum(a/b)"), + xpath_double(col("xml"), "sum(a/b)"), + xpath_number(col("xml"), "sum(a/b)")), + Row(3.1.toFloat, 3.1, 3.1)) } test("xpath_string") { val df = Seq("bcc").toDF("xml") checkAnswer(df.selectExpr("xpath_string(xml, 'a/c')"), Row("cc")) + checkAnswer(df.select(xpath_string(col("xml"), "a/c")), Row("cc")) } test("xpath") { val df = Seq("b1b2b3c1c2").toDF("xml") checkAnswer(df.selectExpr("xpath(xml, 'a/*/text()')"), Row(Seq("b1", "b2", "b3", "c1", "c2"))) + checkAnswer(df.select(xpath(col("xml"), "a/*/text()")), + Row(Seq("b1", "b2", "b3", "c1", "c2"))) } } From 62f7eee7b761054cef3a690552819709c3eb4c50 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 6 Jun 2023 14:27:57 +0800 Subject: [PATCH 2/3] [SPARK-43935][SQL][PYTHON][CONNECT] Add xpath_* functions to Scala and Python --- .../org/apache/spark/sql/functions.scala | 36 ++++++------ .../spark/sql/PlanGenerationTestSuite.scala | 18 +++--- python/pyspark/sql/connect/functions.py | 36 ++++++------ python/pyspark/sql/functions.py | 57 ++++++++++--------- .../org/apache/spark/sql/functions.scala | 36 ++++++------ .../spark/sql/XPathFunctionsSuite.scala | 18 +++--- 6 files changed, 101 insertions(+), 100 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index c7a1f501a8749..bc18d24059333 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -4254,8 +4254,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath(xml: Column, path: String): Column = - Column.fn("xpath", xml, lit(path)) + def xpath(xml: Column, path: Column): Column = + Column.fn("xpath", xml, path) /** * Returns true if the XPath expression evaluates to true, or if a matching node is found. @@ -4263,8 +4263,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_boolean(xml: Column, path: String): Column = - Column.fn("xpath_boolean", xml, lit(path)) + def xpath_boolean(xml: Column, path: Column): Column = + Column.fn("xpath_boolean", xml, path) /** * Returns a double value, the value zero if no match is found, or NaN if a match is found but @@ -4273,8 +4273,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_double(xml: Column, path: String): Column = - Column.fn("xpath_double", xml, lit(path)) + def xpath_double(xml: Column, path: Column): Column = + Column.fn("xpath_double", xml, path) /** * Returns a double value, the value zero if no match is found, or NaN if a match is found but @@ -4283,8 +4283,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_number(xml: Column, path: String): Column = - Column.fn("xpath_number", xml, lit(path)) + def xpath_number(xml: Column, path: Column): Column = + Column.fn("xpath_number", xml, path) /** * Returns a float value, the value zero if no match is found, or NaN if a match is found but @@ -4293,8 +4293,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_float(xml: Column, path: String): Column = - Column.fn("xpath_float", xml, lit(path)) + def xpath_float(xml: Column, path: Column): Column = + Column.fn("xpath_float", xml, path) /** * Returns an integer value, or the value zero if no match is found, or a match is found but the @@ -4303,8 +4303,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_int(xml: Column, path: String): Column = - Column.fn("xpath_int", xml, lit(path)) + def xpath_int(xml: Column, path: Column): Column = + Column.fn("xpath_int", xml, path) /** * Returns a long integer value, or the value zero if no match is found, or a match is found but @@ -4313,8 +4313,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_long(xml: Column, path: String): Column = - Column.fn("xpath_long", xml, lit(path)) + def xpath_long(xml: Column, path: Column): Column = + Column.fn("xpath_long", xml, path) /** * Returns a short integer value, or the value zero if no match is found, or a match is found @@ -4323,8 +4323,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_short(xml: Column, path: String): Column = - Column.fn("xpath_short", xml, lit(path)) + def xpath_short(xml: Column, path: Column): Column = + Column.fn("xpath_short", xml, path) /** * Returns the text contents of the first xml node that matches the XPath expression. @@ -4332,8 +4332,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_string(xml: Column, path: String): Column = - Column.fn("xpath_string", xml, lit(path)) + def xpath_string(xml: Column, path: Column): Column = + Column.fn("xpath_string", xml, path) private def newLambdaVariable(name: String): proto.Expression.UnresolvedNamedLambdaVariable = { proto.Expression.UnresolvedNamedLambdaVariable diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 8df0907c4ca2f..21d459012e1cf 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -1656,39 +1656,39 @@ class PlanGenerationTestSuite } temporalFunctionTest("xpath") { - fn.xpath(fn.col("s"), "a/b/text()") + fn.xpath(fn.col("s"), lit("a/b/text()")) } temporalFunctionTest("xpath_boolean") { - fn.xpath_boolean(fn.col("s"), "a/b") + fn.xpath_boolean(fn.col("s"), lit("a/b")) } temporalFunctionTest("xpath_double") { - fn.xpath_double(fn.col("s"), "a/b") + fn.xpath_double(fn.col("s"), lit("a/b")) } temporalFunctionTest("xpath_number") { - fn.xpath_number(fn.col("s"), "a/b") + fn.xpath_number(fn.col("s"), lit("a/b")) } temporalFunctionTest("xpath_float") { - fn.xpath_float(fn.col("s"), "a/b") + fn.xpath_float(fn.col("s"), lit("a/b")) } temporalFunctionTest("xpath_int") { - fn.xpath_int(fn.col("s"), "a/b") + fn.xpath_int(fn.col("s"), lit("a/b")) } temporalFunctionTest("xpath_long") { - fn.xpath_long(fn.col("s"), "a/b") + fn.xpath_long(fn.col("s"), lit("a/b")) } temporalFunctionTest("xpath_short") { - fn.xpath_short(fn.col("s"), "a/b") + fn.xpath_short(fn.col("s"), lit("a/b")) } temporalFunctionTest("xpath_string") { - fn.xpath_string(fn.col("s"), "a/b") + fn.xpath_string(fn.col("s"), lit("a/b")) } temporalFunctionTest("trunc") { diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index ae4b9f60e2064..583c520eb479f 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -2168,64 +2168,64 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: to_timestamp.__doc__ = pysparkfuncs.to_timestamp.__doc__ -def xpath(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath", _to_col(xml), lit(path)) +def xpath(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath", xml, path) xpath.__doc__ = pysparkfuncs.xpath.__doc__ -def xpath_boolean(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_boolean", _to_col(xml), lit(path)) +def xpath_boolean(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_boolean", xml, path) xpath_boolean.__doc__ = pysparkfuncs.xpath_boolean.__doc__ -def xpath_double(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_double", _to_col(xml), lit(path)) +def xpath_double(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_double", xml, path) xpath_double.__doc__ = pysparkfuncs.xpath_double.__doc__ -def xpath_number(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_number", _to_col(xml), lit(path)) +def xpath_number(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_number", xml, path) xpath_number.__doc__ = pysparkfuncs.xpath_number.__doc__ -def xpath_float(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_float", _to_col(xml), lit(path)) +def xpath_float(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_float", xml, path) xpath_float.__doc__ = pysparkfuncs.xpath_float.__doc__ -def xpath_int(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_int", _to_col(xml), lit(path)) +def xpath_int(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_int", xml, path) xpath_int.__doc__ = pysparkfuncs.xpath_int.__doc__ -def xpath_long(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_long", _to_col(xml), lit(path)) +def xpath_long(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_long", xml, path) xpath_long.__doc__ = pysparkfuncs.xpath_long.__doc__ -def xpath_short(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_short", _to_col(xml), lit(path)) +def xpath_short(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_short", xml, path) xpath_short.__doc__ = pysparkfuncs.xpath_short.__doc__ -def xpath_string(xml: "ColumnOrName", path: str) -> Column: - return _invoke_function("xpath_string", _to_col(xml), lit(path)) +def xpath_string(xml: "ColumnOrName", path: "ColumnOrName") -> Column: + return _invoke_function_over_columns("xpath_string", xml, path) xpath_string.__doc__ = pysparkfuncs.xpath_string.__doc__ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 61aa06c14606e..8e4558dc4bca3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -5034,129 +5034,130 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: @try_remote_functions -def xpath(xml: "ColumnOrName", path: str) -> Column: +def xpath(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns a string array of values within the nodes of xml that match the XPath expression. Examples -------- - >>> df = spark.createDataFrame([('b1b2b3c1c2',)], ['x']) - >>> df.select(xpath(df.x, 'a/b/text()').alias('r')).collect() + >>> df = spark.createDataFrame( + ... [('b1b2b3c1c2',)], ['x']) + >>> df.select(xpath(df.x, lit('a/b/text()')).alias('r')).collect() [Row(r=['b1', 'b2', 'b3'])] """ - return _invoke_function("xpath", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath", xml, path) @try_remote_functions -def xpath_boolean(xml: "ColumnOrName", path: str) -> Column: +def xpath_boolean(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns true if the XPath expression evaluates to true, or if a matching node is found. Examples -------- >>> df = spark.createDataFrame([('1',)], ['x']) - >>> df.select(xpath_boolean(df.x, 'a/b').alias('r')).collect() + >>> df.select(xpath_boolean(df.x, lit('a/b')).alias('r')).collect() [Row(r=True)] """ - return _invoke_function("xpath_boolean", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_boolean", xml, path) @try_remote_functions -def xpath_double(xml: "ColumnOrName", path: str) -> Column: +def xpath_double(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns a double value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. Examples -------- >>> df = spark.createDataFrame([('12',)], ['x']) - >>> df.select(xpath_double(df.x, 'sum(a/b)').alias('r')).collect() + >>> df.select(xpath_double(df.x, lit('sum(a/b)')).alias('r')).collect() [Row(r=3.0)] """ - return _invoke_function("xpath_double", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_double", xml, path) @try_remote_functions -def xpath_number(xml: "ColumnOrName", path: str) -> Column: +def xpath_number(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns a double value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. Examples -------- >>> df = spark.createDataFrame([('12',)], ['x']) - >>> df.select(xpath_number(df.x, 'sum(a/b)').alias('r')).collect() + >>> df.select(xpath_number(df.x, lit('sum(a/b)')).alias('r')).collect() [Row(r=3.0)] """ - return _invoke_function("xpath_number", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_number", xml, path) @try_remote_functions -def xpath_float(xml: "ColumnOrName", path: str) -> Column: +def xpath_float(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns a float value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. Examples -------- >>> df = spark.createDataFrame([('12',)], ['x']) - >>> df.select(xpath_float(df.x, 'sum(a/b)').alias('r')).collect() + >>> df.select(xpath_float(df.x, lit('sum(a/b)')).alias('r')).collect() [Row(r=3.0)] """ - return _invoke_function("xpath_float", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_float", xml, path) @try_remote_functions -def xpath_int(xml: "ColumnOrName", path: str) -> Column: +def xpath_int(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns an integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. Examples -------- >>> df = spark.createDataFrame([('12',)], ['x']) - >>> df.select(xpath_int(df.x, 'sum(a/b)').alias('r')).collect() + >>> df.select(xpath_int(df.x, lit('sum(a/b)')).alias('r')).collect() [Row(r=3)] """ - return _invoke_function("xpath_int", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_int", xml, path) @try_remote_functions -def xpath_long(xml: "ColumnOrName", path: str) -> Column: +def xpath_long(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns a long integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. Examples -------- >>> df = spark.createDataFrame([('12',)], ['x']) - >>> df.select(xpath_long(df.x, 'sum(a/b)').alias('r')).collect() + >>> df.select(xpath_long(df.x, lit('sum(a/b)')).alias('r')).collect() [Row(r=3)] """ - return _invoke_function("xpath_long", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_long", xml, path) @try_remote_functions -def xpath_short(xml: "ColumnOrName", path: str) -> Column: +def xpath_short(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns a short integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. Examples -------- >>> df = spark.createDataFrame([('12',)], ['x']) - >>> df.select(xpath_short(df.x, 'sum(a/b)').alias('r')).collect() + >>> df.select(xpath_short(df.x, lit('sum(a/b)')).alias('r')).collect() [Row(r=3)] """ - return _invoke_function("xpath_short", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_short", xml, path) @try_remote_functions -def xpath_string(xml: "ColumnOrName", path: str) -> Column: +def xpath_string(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ Returns the text contents of the first xml node that matches the XPath expression. Examples -------- >>> df = spark.createDataFrame([('bcc',)], ['x']) - >>> df.select(xpath_string(df.x, 'a/c').alias('r')).collect() + >>> df.select(xpath_string(df.x, lit('a/c')).alias('r')).collect() [Row(r='cc')] """ - return _invoke_function("xpath_string", _to_java_column(xml), path) + return _invoke_function_over_columns("xpath_string", xml, path) @try_remote_functions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index eb5722eadf0ef..3303585bec9ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -5249,8 +5249,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath(x: Column, p: String): Column = withExpr { - XPathList(x.expr, lit(p).expr) + def xpath(x: Column, p: Column): Column = withExpr { + XPathList(x.expr, p.expr) } /** @@ -5259,8 +5259,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_boolean(x: Column, p: String): Column = withExpr { - XPathBoolean(x.expr, lit(p).expr) + def xpath_boolean(x: Column, p: Column): Column = withExpr { + XPathBoolean(x.expr, p.expr) } /** @@ -5270,8 +5270,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_double(x: Column, p: String): Column = withExpr { - XPathDouble(x.expr, lit(p).expr) + def xpath_double(x: Column, p: Column): Column = withExpr { + XPathDouble(x.expr, p.expr) } /** @@ -5281,8 +5281,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_number(x: Column, p: String): Column = withExpr { - XPathDouble(x.expr, lit(p).expr) + def xpath_number(x: Column, p: Column): Column = withExpr { + XPathDouble(x.expr, p.expr) } /** @@ -5292,8 +5292,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_float(x: Column, p: String): Column = withExpr { - XPathFloat(x.expr, lit(p).expr) + def xpath_float(x: Column, p: Column): Column = withExpr { + XPathFloat(x.expr, p.expr) } /** @@ -5303,8 +5303,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_int(x: Column, p: String): Column = withExpr { - XPathInt(x.expr, lit(p).expr) + def xpath_int(x: Column, p: Column): Column = withExpr { + XPathInt(x.expr, p.expr) } /** @@ -5314,8 +5314,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_long(x: Column, p: String): Column = withExpr { - XPathLong(x.expr, lit(p).expr) + def xpath_long(x: Column, p: Column): Column = withExpr { + XPathLong(x.expr, p.expr) } /** @@ -5325,8 +5325,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_short(x: Column, p: String): Column = withExpr { - XPathShort(x.expr, lit(p).expr) + def xpath_short(x: Column, p: Column): Column = withExpr { + XPathShort(x.expr, p.expr) } /** @@ -5335,8 +5335,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_string(x: Column, p: String): Column = withExpr { - XPathString(x.expr, lit(p).expr) + def xpath_string(x: Column, p: Column): Column = withExpr { + XPathString(x.expr, p.expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala index cc2d3ba5ef4c0..f08466e8f8d9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala @@ -29,7 +29,7 @@ class XPathFunctionsSuite extends QueryTest with SharedSparkSession { test("xpath_boolean") { val df = Seq("b").toDF("xml") checkAnswer(df.selectExpr("xpath_boolean(xml, 'a/b')"), Row(true)) - checkAnswer(df.select(xpath_boolean(col("xml"), "a/b")), Row(true)) + checkAnswer(df.select(xpath_boolean(col("xml"), lit("a/b"))), Row(true)) } test("xpath_short, xpath_int, xpath_long") { @@ -42,9 +42,9 @@ class XPathFunctionsSuite extends QueryTest with SharedSparkSession { Row(3.toShort, 3, 3L)) checkAnswer( df.select( - xpath_short(col("xml"), "sum(a/b)"), - xpath_int(col("xml"), "sum(a/b)"), - xpath_long(col("xml"), "sum(a/b)")), + xpath_short(col("xml"), lit("sum(a/b)")), + xpath_int(col("xml"), lit("sum(a/b)")), + xpath_long(col("xml"), lit("sum(a/b)"))), Row(3.toShort, 3, 3L)) } @@ -58,22 +58,22 @@ class XPathFunctionsSuite extends QueryTest with SharedSparkSession { Row(3.1.toFloat, 3.1, 3.1)) checkAnswer( df.select( - xpath_float(col("xml"), "sum(a/b)"), - xpath_double(col("xml"), "sum(a/b)"), - xpath_number(col("xml"), "sum(a/b)")), + xpath_float(col("xml"), lit("sum(a/b)")), + xpath_double(col("xml"), lit("sum(a/b)")), + xpath_number(col("xml"), lit("sum(a/b)"))), Row(3.1.toFloat, 3.1, 3.1)) } test("xpath_string") { val df = Seq("bcc").toDF("xml") checkAnswer(df.selectExpr("xpath_string(xml, 'a/c')"), Row("cc")) - checkAnswer(df.select(xpath_string(col("xml"), "a/c")), Row("cc")) + checkAnswer(df.select(xpath_string(col("xml"), lit("a/c"))), Row("cc")) } test("xpath") { val df = Seq("b1b2b3c1c2").toDF("xml") checkAnswer(df.selectExpr("xpath(xml, 'a/*/text()')"), Row(Seq("b1", "b2", "b3", "c1", "c2"))) - checkAnswer(df.select(xpath(col("xml"), "a/*/text()")), + checkAnswer(df.select(xpath(col("xml"), lit("a/*/text()"))), Row(Seq("b1", "b2", "b3", "c1", "c2"))) } } From ebe6f4f31980468f1a0d5acdad591121bbf437e2 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 6 Jun 2023 22:01:55 +0800 Subject: [PATCH 3/3] [SPARK-43935][SQL][PYTHON][CONNECT] Add xpath_* functions to Scala and Python --- python/pyspark/sql/functions.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8e4558dc4bca3..6f700015b1f25 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -5065,7 +5065,8 @@ def xpath_boolean(xml: "ColumnOrName", path: "ColumnOrName") -> Column: @try_remote_functions def xpath_double(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ - Returns a double value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. + Returns a double value, the value zero if no match is found, + or NaN if a match is found but the value is non-numeric. Examples -------- @@ -5079,7 +5080,8 @@ def xpath_double(xml: "ColumnOrName", path: "ColumnOrName") -> Column: @try_remote_functions def xpath_number(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ - Returns a double value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. + Returns a double value, the value zero if no match is found, + or NaN if a match is found but the value is non-numeric. Examples -------- @@ -5093,7 +5095,8 @@ def xpath_number(xml: "ColumnOrName", path: "ColumnOrName") -> Column: @try_remote_functions def xpath_float(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ - Returns a float value, the value zero if no match is found, or NaN if a match is found but the value is non-numeric. + Returns a float value, the value zero if no match is found, + or NaN if a match is found but the value is non-numeric. Examples -------- @@ -5107,7 +5110,8 @@ def xpath_float(xml: "ColumnOrName", path: "ColumnOrName") -> Column: @try_remote_functions def xpath_int(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ - Returns an integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. + Returns an integer value, or the value zero if no match is found, + or a match is found but the value is non-numeric. Examples -------- @@ -5121,7 +5125,8 @@ def xpath_int(xml: "ColumnOrName", path: "ColumnOrName") -> Column: @try_remote_functions def xpath_long(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ - Returns a long integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. + Returns a long integer value, or the value zero if no match is found, + or a match is found but the value is non-numeric. Examples -------- @@ -5135,7 +5140,8 @@ def xpath_long(xml: "ColumnOrName", path: "ColumnOrName") -> Column: @try_remote_functions def xpath_short(xml: "ColumnOrName", path: "ColumnOrName") -> Column: """ - Returns a short integer value, or the value zero if no match is found, or a match is found but the value is non-numeric. + Returns a short integer value, or the value zero if no match is found, + or a match is found but the value is non-numeric. Examples --------