diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 2fadf20da491c..a9cca4bf6f6fc 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -230,6 +230,7 @@ exportMethods("%<=>%",
               "asc",
               "ascii",
               "asin",
+              "assert_true",
               "atan",
               "atan2",
               "avg",
@@ -361,6 +362,7 @@ exportMethods("%<=>%",
               "posexplode_outer",
               "quarter",
               "radians",
+              "raise_error",
               "rand",
               "randn",
               "rank",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index ce384a64bccaf..bcd798a8c31e2 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -847,7 +847,8 @@ setMethod("assert_true",
             jc <- if (is.null(errMsg)) {
               callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc)
             } else {
-              if (is.character(errMsg) && length(errMsg) == 1) {
+              if (is.character(errMsg)) {
+                stopifnot(length(errMsg) == 1)
                 errMsg <- lit(errMsg)
               }
               callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc, errMsg@jc)
@@ -868,7 +869,8 @@ setMethod("raise_error",
           signature(x = "characterOrColumn"),
           function(x) {
-            if (is.character(x) && length(x) == 1) {
+            if (is.character(x)) {
+              stopifnot(length(x) == 1)
               x <- lit(x)
             }
             jc <- callJStatic("org.apache.spark.sql.functions", "raise_error", x@jc)
diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi
index 6249bca5cef68..779a29c086d5a 100644
--- a/python/pyspark/sql/functions.pyi
+++ b/python/pyspark/sql/functions.pyi
@@ -137,8 +137,8 @@ def sha1(col: ColumnOrName) -> Column: ...
 def sha2(col: ColumnOrName, numBits: int) -> Column: ...
 def hash(*cols: ColumnOrName) -> Column: ...
 def xxhash64(*cols: ColumnOrName) -> Column: ...
-def assert_true(col: ColumnOrName, errMsg: Union[Column, str] = ...): ...
-def raise_error(errMsg: Union[Column, str]): ...
+def assert_true(col: ColumnOrName, errMsg: Union[Column, str] = ...) -> Column: ...
+def raise_error(errMsg: Union[Column, str]) -> Column: ...
 def concat(*cols: ColumnOrName) -> Column: ...
 def concat_ws(sep: str, *cols: ColumnOrName) -> Column: ...
 def decode(col: ColumnOrName, charset: str) -> Column: ...
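
For context, here is a minimal SparkR sketch of how the two newly exported functions are used and what the stricter argument check changes; the session setup, the example data frame, and the error messages are illustrative assumptions, not part of the patch.

library(SparkR)
sparkR.session()  # assumes a local Spark installation is available

df <- createDataFrame(data.frame(id = c(1, 2, 3)))

# assert_true(x, errMsg) evaluates to NULL for rows where the condition holds
# and fails the query with errMsg otherwise. A character errMsg must have
# length 1; with this change, a longer vector now stops early in R with a
# clear stopifnot(length(errMsg) == 1) error instead of falling through to
# errMsg@jc on a plain character vector.
head(select(df, assert_true(df$id > 0, "id must be positive")))

# raise_error(x) unconditionally raises an error with the given message when
# the expression is evaluated, so it is left commented out here.
# head(select(df, raise_error("unreachable branch reached")))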