Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-32793][SQL] Add raise_error function, adds error message parameter to assert_true #29947

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,55 @@ setMethod("xxhash64",
column(jc)
})

#' @details
#' \code{assert_true}: Returns null if the input column is true; throws an exception
#' with the provided error message otherwise.
#'
#' @param errMsg (optional) The error message to be thrown.
#'
#' @rdname column_misc_functions
#' @aliases assert_true assert_true,Column-method
#' @examples
#' \dontrun{
#' tmp <- mutate(df, v1 = assert_true(df$vs < 2),
#' v2 = assert_true(df$vs < 2, "custom error message"),
#' v3 = assert_true(df$vs < 2, df$vs))
#' head(tmp)}
#' @note assert_true since 3.1.0
setMethod("assert_true",
signature(x = "Column"),
function(x, errMsg = NULL) {
jc <- if (is.null(errMsg)) {
callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc)
} else {
if (is.character(errMsg) && length(errMsg) == 1) {
Copy link
Member

@zero323 zero323 Oct 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we throw an exception if length(errMsg) != 1? Just in case user does something like this?

> assert_true(column("foo"), c("foo", "bar"))
Error in invokeJava(isStatic = TRUE, className, methodName, ...) : 
  trying to get slot "jc" from an object of a basic class ("character") with no slots

i.e.

           ...
            } else {
              if (is.character(errMsg) {
                stopifnot(length(errMsg) == 1)
           ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, more checks should be fine.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice we make this check anyway, so it is only a question if we do something about it.

errMsg <- lit(errMsg)
}
callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc, errMsg@jc)
}
column(jc)
})

#' @details
#' \code{raise_error}: Throws an exception with the provided error message.
#'
#' @rdname column_misc_functions
#' @aliases raise_error raise_error,characterOrColumn-method
#' @examples
#' \dontrun{
#' tmp <- mutate(df, v1 = raise_error("error message"))
#' head(tmp)}
#' @note raise_error since 3.1.0
setMethod("raise_error",
signature(x = "characterOrColumn"),
function(x) {
if (is.character(x) && length(x) == 1) {
x <- lit(x)
}
jc <- callJStatic("org.apache.spark.sql.functions", "raise_error", x@jc)
column(jc)
})

#' @details
#' \code{dayofmonth}: Extracts the day of the month as an integer from a
#' given date/timestamp/string.
Expand Down
8 changes: 8 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,10 @@ setGeneric("arrays_zip_with", function(x, y, f) { standardGeneric("arrays_zip_wi
#' @name NULL
setGeneric("ascii", function(x) { standardGeneric("ascii") })

#' @rdname column_misc_functions
#' @name NULL
setGeneric("assert_true", function(x, errMsg = NULL) { standardGeneric("assert_true") })

#' @param x Column to compute on or a GroupedData object.
#' @param ... additional argument(s) when \code{x} is a GroupedData object.
#' @rdname avg
Expand Down Expand Up @@ -1220,6 +1224,10 @@ setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer")
#' @name NULL
setGeneric("quarter", function(x) { standardGeneric("quarter") })

#' @rdname column_misc_functions
#' @name NULL
setGeneric("raise_error", function(x) { standardGeneric("raise_error") })

#' @rdname column_nonaggregate_functions
#' @name NULL
setGeneric("rand", function(seed) { standardGeneric("rand") })
Expand Down
18 changes: 18 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -3928,6 +3928,24 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", {
dropTempView("cars")
})

test_that("assert_true, raise_error", {
df <- read.json(jsonPath)
filtered <- filter(df, "age < 20")

expect_equal(collect(select(filtered, assert_true(filtered$age < 20)))$age, c(NULL))
expect_equal(collect(select(filtered, assert_true(filtered$age < 20, "error message")))$age,
c(NULL))
expect_equal(collect(select(filtered, assert_true(filtered$age < 20, filtered$name)))$age,
c(NULL))
expect_error(collect(select(df, assert_true(df$age < 20))), "is not true!")
expect_error(collect(select(df, assert_true(df$age < 20, "error message"))),
"error message")
expect_error(collect(select(df, assert_true(df$age < 20, df$name))), "Michael")

expect_error(collect(select(filtered, raise_error("error message"))), "error message")
expect_error(collect(select(filtered, raise_error(filtered$name))), "Justin")
})

compare_list <- function(list1, list2) {
# get testthat to show the diff by first making the 2 lists equal in length
expect_equal(length(list1), length(list2))
Expand Down
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ Functions
asc_nulls_last
ascii
asin
assert_true
atan
atan2
avg
Expand Down Expand Up @@ -420,6 +421,7 @@ Functions
pow
quarter
radians
raise_error
rand
randn
rank
Expand Down
55 changes: 53 additions & 2 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,57 @@ def xxhash64(*cols):
return Column(jc)


karenfeng marked this conversation as resolved.
Show resolved Hide resolved
@since(3.1)
def assert_true(col, errMsg=None):
"""
Returns null if the input column is true; throws an exception with the provided error message
otherwise.

>>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b).alias('r')).collect()
[Row(r=None)]
>>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b, df.a).alias('r')).collect()
[Row(r=None)]
>>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
>>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect()
[Row(r=None)]
"""
sc = SparkContext._active_spark_context
if errMsg is None:
return Column(sc._jvm.functions.assert_true(_to_java_column(col)))
if not isinstance(errMsg, (str, Column)):
raise TypeError(
"errMsg should be a Column or a str, got {}".format(type(errMsg))
)

errMsg = (
_create_column_from_literal(errMsg)
if isinstance(errMsg, str)
else _to_java_column(errMsg)
)
return Column(sc._jvm.functions.assert_true(_to_java_column(col), errMsg))


@since(3.1)
def raise_error(errMsg):
"""
Throws an exception with the provided error message.
"""
if not isinstance(errMsg, (str, Column)):
raise TypeError(
"errMsg should be a Column or a str, got {}".format(type(errMsg))
)

sc = SparkContext._active_spark_context
errMsg = (
_create_column_from_literal(errMsg)
if isinstance(errMsg, str)
else _to_java_column(errMsg)
)
return Column(sc._jvm.functions.raise_error(errMsg))


# ---------------------- String/Binary functions ------------------------------

_string_functions = {
Expand Down Expand Up @@ -3448,14 +3499,14 @@ def bucket(numBuckets, col):
... ).createOrReplace()

.. warning::
This function can be used only in combinatiion with
This function can be used only in combination with
:py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy`
method of the `DataFrameWriterV2`.

"""
if not isinstance(numBuckets, (int, Column)):
raise TypeError(
"numBuckets should be a Column or and int, got {}".format(type(numBuckets))
"numBuckets should be a Column or an int, got {}".format(type(numBuckets))
)

sc = SparkContext._active_spark_context
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/sql/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def sha1(col: ColumnOrName) -> Column: ...
def sha2(col: ColumnOrName, numBits: int) -> Column: ...
def hash(*cols: ColumnOrName) -> Column: ...
def xxhash64(*cols: ColumnOrName) -> Column: ...
def assert_true(col: ColumnOrName, errMsg: Union[Column, str] = ...): ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two small notes (sorry for being late):

  • I think we should annotate return type for assert_true - it will type check because of implicit Any, but I think it is better to avoid such cases

    def assert_true(col: ColumnOrName, errMsg: Union[Column, str] = ...) -> Column: ...
  • For def raise_error I'd use NoReturn:

    from typing import NoReturn
    
    def raise_error(errMsg: Union[Column, str]) -> NoReturn: ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karenfeng Could you fix them above in followup?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might indicate intention here, though technically speaking it's still a Column, so

def raise_error(errMsg: Union[Column, str]) -> Column: ...

is still correct (and literal one). Do you have any thoughts about it @HyukjinKwon?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think doing Column is fine.

def raise_error(errMsg: Union[Column, str]): ...
def concat(*cols: ColumnOrName) -> Column: ...
def concat_ws(sep: str, *cols: ColumnOrName) -> Column: ...
def decode(col: ColumnOrName, charset: str) -> Column: ...
Expand Down
50 changes: 50 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from itertools import chain
import re

from py4j.protocol import Py4JJavaError
from pyspark.sql import Row, Window
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, lit
from pyspark.testing.sqlutils import ReusedSQLTestCase
Expand Down Expand Up @@ -524,6 +525,55 @@ def test_datetime_functions(self):
parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)'])

def test_assert_true(self):
from pyspark.sql.functions import assert_true

df = self.spark.range(3)

self.assertEquals(
df.select(assert_true(df.id < 3)).toDF("val").collect(),
[Row(val=None), Row(val=None), Row(val=None)],
)

with self.assertRaises(Py4JJavaError) as cm:
df.select(assert_true(df.id < 2, 'too big')).toDF("val").collect()
self.assertIn("java.lang.RuntimeException", str(cm.exception))
self.assertIn("too big", str(cm.exception))

with self.assertRaises(Py4JJavaError) as cm:
df.select(assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect()
self.assertIn("java.lang.RuntimeException", str(cm.exception))
self.assertIn("2000000", str(cm.exception))

with self.assertRaises(TypeError) as cm:
df.select(assert_true(df.id < 2, 5))
self.assertEquals(
"errMsg should be a Column or a str, got <class 'int'>",
str(cm.exception)
)

def test_raise_error(self):
from pyspark.sql.functions import raise_error

df = self.spark.createDataFrame([Row(id="foobar")])

with self.assertRaises(Py4JJavaError) as cm:
df.select(raise_error(df.id)).collect()
self.assertIn("java.lang.RuntimeException", str(cm.exception))
self.assertIn("foobar", str(cm.exception))

with self.assertRaises(Py4JJavaError) as cm:
df.select(raise_error("barfoo")).collect()
self.assertIn("java.lang.RuntimeException", str(cm.exception))
self.assertIn("barfoo", str(cm.exception))

with self.assertRaises(TypeError) as cm:
df.select(raise_error(None))
self.assertEquals(
"errMsg should be a Column or a str, got <class 'NoneType'>",
str(cm.exception)
)


if __name__ == "__main__":
import unittest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ object FunctionRegistry {

// misc functions
expression[AssertTrue]("assert_true"),
expression[RaiseError]("raise_error"),
expression[Crc32]("crc32"),
expression[Md5]("md5"),
expression[Uuid]("uuid"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,51 +53,81 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
}

/**
* A function throws an exception if 'condition' is not true.
* Throw with the result of an expression (used for debugging).
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.",
usage = "_FUNC_(expr) - Throws an exception with `expr`.",
examples = """
Examples:
> SELECT _FUNC_(0 < 1);
NULL
> SELECT _FUNC_('custom error message');
java.lang.RuntimeException
custom error message
""",
since = "2.0.0")
case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
since = "3.1.0")
case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
viirya marked this conversation as resolved.
Show resolved Hide resolved

override def foldable: Boolean = false
override def nullable: Boolean = true

override def inputTypes: Seq[DataType] = Seq(BooleanType)

override def dataType: DataType = NullType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def prettyName: String = "assert_true"
override def prettyName: String = "raise_error"

private val errMsg = s"'${child.simpleString(SQLConf.get.maxToStringFields)}' is not true!"

override def eval(input: InternalRow) : Any = {
val v = child.eval(input)
if (v == null || java.lang.Boolean.FALSE.equals(v)) {
throw new RuntimeException(errMsg)
} else {
null
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
throw new RuntimeException()
}
throw new RuntimeException(value.toString)
}

// if (true) is to avoid codegen compilation exception that statement is unreachable
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ExprCode(
code = code"""${eval.code}
|if (true) {
| if (${eval.isNull}) {
| throw new RuntimeException();
| }
| throw new RuntimeException(${eval.value}.toString());
|}""".stripMargin,
isNull = TrueLiteral,
viirya marked this conversation as resolved.
Show resolved Hide resolved
value = JavaCode.defaultLiteral(dataType)
)
}
}

// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null or false.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
ExprCode(code = code"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
|}""".stripMargin, isNull = TrueLiteral,
value = JavaCode.defaultLiteral(dataType))
/**
* A function that throws an exception if 'condition' is not true.
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.",
examples = """
Examples:
> SELECT _FUNC_(0 < 1);
NULL
""",
since = "2.0.0")
case class AssertTrue(left: Expression, right: Expression, child: Expression)
extends RuntimeReplaceable {

override def prettyName: String = "assert_true"

def this(left: Expression, right: Expression) = {
this(left, right, If(left, Literal(null), RaiseError(right)))
}

override def sql: String = s"assert_true(${child.sql})"
def this(left: Expression) = {
this(left, Literal(s"'${left.simpleString(SQLConf.get.maxToStringFields)}' is not true!"))
}

override def flatArguments: Iterator[Any] = Iterator(left, right)
override def exprsReplaced: Seq[Expression] = Seq(left, right)
}

object AssertTrue {
def apply(left: Expression): AssertTrue = new AssertTrue(left)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("SPARK-17160: field names are properly escaped by AssertTrue") {
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)).child :: Nil)
}

test("should not apply common subexpression elimination on conditional expressions") {
Expand Down
Loading