Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-43943][SQL][PYTHON][CONNECT] Add SQL math functions to Scala and Python #41435

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,14 @@ object functions {
*/
def skewness(columnName: String): Column = skewness(Column(columnName))

/**
* Aggregate function: alias for `stddev_samp`.
*
* This overload forwards to `stddev(e)`, so both names produce the same expression.
*
* @param e
* column to aggregate
* @return
* the column produced by `stddev(e)`
* @group agg_funcs
* @since 3.5.0
*/
def std(e: Column): Column = stddev(e)

/**
* Aggregate function: alias for `stddev_samp`.
*
Expand Down Expand Up @@ -1959,6 +1967,22 @@ object functions {
*/
def ceil(columnName: String): Column = ceil(Column(columnName))

/**
* Computes the ceiling of the given value of `e` to `scale` decimal places.
*
* Alias that forwards to `ceil(e, scale)`.
*
* @param e
* column holding the value whose ceiling is computed
* @param scale
* column holding the number of decimal places to keep
* @return
* the column produced by `ceil(e, scale)`
* @group math_funcs
* @since 3.5.0
*/
def ceiling(e: Column, scale: Column): Column = ceil(e, scale)

/**
* Computes the ceiling of the given value of `e` to 0 decimal places.
*
* Alias that forwards to `ceil(e)`.
*
* @param e
* column holding the value whose ceiling is computed
* @return
* the column produced by `ceil(e)`
* @group math_funcs
* @since 3.5.0
*/
def ceiling(e: Column): Column = ceil(e)

/**
* Convert a number in a string column from one base to another.
*
Expand Down Expand Up @@ -2034,6 +2058,14 @@ object functions {
*/
def csc(e: Column): Column = Column.fn("csc", e)

/**
* Returns Euler's number.
*
* Takes no arguments; builds a call to the function named `e` with an empty argument list.
*
* @return
* a column for the argument-less function call `e`
* @group math_funcs
* @since 3.5.0
*/
def e(): Column = Column.fn("e")

/**
* Computes the exponential of the given value.
*
Expand Down Expand Up @@ -2222,6 +2254,14 @@ object functions {
def least(columnName: String, columnNames: String*): Column =
least((columnName +: columnNames).map(Column.apply): _*)

/**
* Computes the natural logarithm of the given value.
*
* Alias that forwards to the single-argument `log(e)`.
*
* @param e
* column holding the value whose natural logarithm is computed
* @return
* the column produced by `log(e)`
* @group math_funcs
* @since 3.5.0
*/
def ln(e: Column): Column = log(e)

/**
* Computes the natural logarithm of the given value.
*
Expand Down Expand Up @@ -2302,6 +2342,30 @@ object functions {
*/
def log2(columnName: String): Column = log2(Column(columnName))

/**
* Returns the negated value.
*
* Builds a call to the function named `negative` with `e` as its only argument.
*
* @param e
* column holding the value to negate
* @return
* a column for the function call `negative(e)`
* @group math_funcs
* @since 3.5.0
*/
def negative(e: Column): Column = Column.fn("negative", e)

/**
* Returns Pi.
*
* Takes no arguments; builds a call to the function named `pi` with an empty argument list.
*
* @return
* a column for the argument-less function call `pi`
* @group math_funcs
* @since 3.5.0
*/
def pi(): Column = Column.fn("pi")

/**
* Returns the value of `e` as is. This is the counterpart of `negative`, i.e. a unary plus
* (`+ e`).
*
* Builds a call to the function named `positive` with `e` as its only argument.
*
* @param e
* column holding the value to return
* @return
* a column for the function call `positive(e)`
* @group math_funcs
* @since 3.5.0
*/
def positive(e: Column): Column = Column.fn("positive", e)

/**
* Returns the value of the first argument raised to the power of the second argument.
*
Expand Down Expand Up @@ -2366,6 +2430,14 @@ object functions {
*/
def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))

/**
* Returns the value of the first argument raised to the power of the second argument.
*
* Alias that forwards to `pow(l, r)`.
*
* @param l
* column holding the base
* @param r
* column holding the exponent
* @return
* the column produced by `pow(l, r)`
* @group math_funcs
* @since 3.5.0
*/
def power(l: Column, r: Column): Column = pow(l, r)

/**
* Returns the positive value of dividend mod divisor.
*
Expand Down Expand Up @@ -2495,6 +2567,14 @@ object functions {
def shiftrightunsigned(e: Column, numBits: Int): Column =
Column.fn("shiftrightunsigned", e, lit(numBits))

/**
* Computes the signum of the given value.
*
* Alias that forwards to `signum(e)`.
*
* @param e
* column holding the value whose signum is computed
* @return
* the column produced by `signum(e)`
* @group math_funcs
* @since 3.5.0
*/
def sign(e: Column): Column = signum(e)

/**
* Computes the signum of the given value.
*
Expand Down Expand Up @@ -2683,6 +2763,27 @@ object functions {
*/
def radians(columnName: String): Column = radians(Column(columnName))

/**
* Returns the bucket number into which the value of this expression would fall after being
* evaluated. The input arguments must satisfy the preconditions of the `width_bucket` function;
* otherwise, the method will return null.
*
* NOTE(review): earlier wording said the conditions were "listed below", but this comment does
* not list any. They are presumably the ones documented for the SQL `width_bucket` function;
* confirm and spell them out here.
*
* @param v
* value to compute a bucket number in the histogram
* @param min
* minimum value of the histogram
* @param max
* maximum value of the histogram
* @param numBucket
* the number of buckets
* @return
* the bucket number into which the value would fall after being evaluated
* @group math_funcs
* @since 3.5.0
*/
def width_bucket(v: Column, min: Column, max: Column, numBucket: Column): Column =
Column.fn("width_bucket", v, min, max, numBucket)

//////////////////////////////////////////////////////////////////////////////////////////////
// Misc functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,10 @@ class PlanGenerationTestSuite
fn.stddev("a")
}

// Covers `fn.std`, called with a Column argument.
functionTest("std") {
fn.std(fn.col("a"))
}

// Covers `fn.stddev_samp`, called with the column name as a String.
functionTest("stddev_samp") {
fn.stddev_samp("a")
}
Expand Down Expand Up @@ -1178,6 +1182,14 @@ class PlanGenerationTestSuite
fn.ceil(fn.col("b"), lit(2))
}

// Covers the single-argument `fn.ceiling` on column `b`.
functionTest("ceiling") {
fn.ceiling(fn.col("b"))
}

// Covers the two-argument `fn.ceiling`, passing the literal 2 as the scale.
functionTest("ceiling scale") {
fn.ceiling(fn.col("b"), lit(2))
}

// Covers `fn.conv` on column `b` with the Int arguments 10 and 16.
functionTest("conv") {
fn.conv(fn.col("b"), 10, 16)
}
Expand All @@ -1198,6 +1210,10 @@ class PlanGenerationTestSuite
fn.csc(fn.col("b"))
}

// Covers the argument-less `fn.e()`.
functionTest("e") {
fn.e()
}

// Covers `fn.exp`, called with the column name as a String.
functionTest("exp") {
fn.exp("b")
}
Expand Down Expand Up @@ -1242,6 +1258,10 @@ class PlanGenerationTestSuite
fn.log("b")
}

// Covers `fn.ln`, called with a Column argument.
functionTest("ln") {
fn.ln(fn.col("b"))
}

// Covers the `fn.log` overload that takes a numeric first argument and a column name.
functionTest("log with base") {
fn.log(2, "b")
}
Expand All @@ -1258,10 +1278,26 @@ class PlanGenerationTestSuite
fn.log2("a")
}

// Covers `fn.negative` on column `a`.
functionTest("negative") {
fn.negative(fn.col("a"))
}

// Covers the argument-less `fn.pi()`.
functionTest("pi") {
fn.pi()
}

// Covers `fn.positive` on column `a`.
functionTest("positive") {
fn.positive(fn.col("a"))
}

// Covers `fn.pow`, called with two column names as Strings.
functionTest("pow") {
fn.pow("a", "b")
}

// Covers `fn.power`, called with two Column arguments (same columns as the `pow` case above).
functionTest("power") {
fn.power(fn.col("a"), fn.col("b"))
}

// Covers `fn.pmod` on column `a` with the literal 10 as the second argument.
functionTest("pmod") {
fn.pmod(fn.col("a"), fn.lit(10))
}
Expand Down Expand Up @@ -1298,6 +1334,10 @@ class PlanGenerationTestSuite
fn.signum("b")
}

// Covers `fn.sign`, called with a Column argument.
functionTest("sign") {
fn.sign(fn.col("b"))
}

// Covers `fn.sin`, called with the column name as a String.
functionTest("sin") {
fn.sin("b")
}
Expand Down Expand Up @@ -2128,6 +2168,10 @@ class PlanGenerationTestSuite
simple.groupBy(Column("id")).pivot("a").agg(functions.count(Column("b")))
}

// Registered with plain `test` on the `simple` dataset rather than through `functionTest`.
// Column `b` is passed as the value and as both histogram bounds; column `a` is passed as the
// number of buckets.
test("width_bucket") {
simple.select(fn.width_bucket(fn.col("b"), fn.col("b"), fn.col("b"), fn.col("a")))
}

// Joins `left` with `right` on "id", wrapping `right` in `fn.broadcast`.
test("test broadcast") {
left.join(fn.broadcast(right), "id")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [CEIL(b#0) AS CEIL(b)#0L]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [ceil(cast(b#0 as decimal(30,15)), 2) AS ceil(b, 2)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [E() AS E()#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [LOG(E(), b#0) AS LOG(E(), b)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [-a#0 AS negative(a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [PI() AS PI()#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [positive(a#0) AS (+ a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [POWER(cast(a#0 as double), b#0) AS POWER(a, b)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [SIGNUM(b#0) AS SIGNUM(b)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Aggregate [stddev(cast(a#0 as double)) AS stddev(a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [width_bucket(b#0, b#0, b#0, cast(a#0 as bigint)) AS width_bucket(b, b, b, a)#0L]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "ceil",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
}]
}
}]
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "ceil",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
}, {
"literal": {
"integer": 2
}
}]
}
}]
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "e"
}
}]
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "log",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
}]
}
}]
}
}
Binary file not shown.
Loading