Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-30607][SQL][PYSPARK][SPARKR] Add overlay wrappers for SparkR and PySpark #27325

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ exportMethods("%<=>%",
"ntile",
"otherwise",
"over",
"overlay",
"percent_rank",
"pmod",
"posexplode",
Expand Down
37 changes: 35 additions & 2 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ NULL
#' format to. See 'Details'.
#' }
#' @param y Column to compute on.
#' @param pos In \itemize{
#' \item \code{locate}: a start position of search.
#' \item \code{overlay}: a start position for replacement.
#' }
#' @param len In \itemize{
#' \item \code{lpad}: the maximum length of each output result.
#' \item \code{overlay}: the number of bytes to replace.
#' }
#' @param ... additional Columns.
#' @name column_string_functions
#' @rdname column_string_functions
Expand Down Expand Up @@ -1319,6 +1327,33 @@ setMethod("negate",
column(jc)
})

#' @details
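#' \code{overlay}: Overlay the specified portion of \code{x} with \code{replace},
#' starting from byte position \code{pos} of \code{x} and proceeding for \code{len} bytes.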
HyukjinKwon marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @param replace a Column with replacement.
#'
#' @rdname column_string_functions
#' @aliases overlay overlay,Column-method,numericOrColumn-method
#' @note overlay since 3.0.0
setMethod("overlay",
          signature(x = "Column", replace = "Column", pos = "numericOrColumn"),
          function(x, replace, pos, len = -1) {
            # Plain numbers are coerced to integers and wrapped with lit();
            # anything else is passed through unchanged and must expose @jc.
            posCol <- if (is.numeric(pos)) lit(as.integer(pos)) else pos
            lenCol <- if (is.numeric(len)) lit(as.integer(len)) else len

            # Delegate to the static overlay function on the JVM side and wrap
            # the returned Java column back into a SparkR Column.
            column(
              callJStatic(
                "org.apache.spark.sql.functions", "overlay",
                x@jc, replace@jc, posCol@jc, lenCol@jc
              )
            )
          })

#' @details
#' \code{quarter}: Extracts the quarter as an integer from a given date/timestamp/string.
#'
Expand Down Expand Up @@ -2819,7 +2854,6 @@ setMethod("window", signature(x = "Column"),
#'
#' @param substr a character string to be matched.
#' @param str a Column where matches are sought for each entry.
#' @param pos start position of search.
#' @rdname column_string_functions
#' @aliases locate locate,character,Column-method
#' @note locate since 1.5.0
Expand All @@ -2834,7 +2868,6 @@ setMethod("locate", signature(substr = "character", str = "Column"),
#' @details
#' \code{lpad}: Left-padded with pad to a length of len.
#'
#' @param len maximum length of each output result.
#' @param pad a character string to be padded with.
#' @rdname column_string_functions
#' @aliases lpad lpad,Column,numeric,character-method
Expand Down
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,10 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") })
#' @name NULL
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("overlay", function(x, replace, pos, ...) {
  # `len` is not a formal of the generic; methods receive it through `...`.
  standardGeneric("overlay")
})

#' @rdname column_window_functions
#' @name NULL
setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") })
Expand Down
2 changes: 2 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,8 @@ test_that("column functions", {
trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm")
c24 <- date_trunc("hour", c) + date_trunc("minute", c) + date_trunc("week", c) +
date_trunc("quarter", c) + current_date() + current_timestamp()
c25 <- overlay(c1, c2, c3, c3) + overlay(c1, c2, c3) + overlay(c1, c2, 1) +
overlay(c1, c2, 3, 4)

# Test if base::is.nan() is exposed
expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE))
Expand Down
34 changes: 34 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,40 @@ def instr(str, substr):
return Column(sc._jvm.functions.instr(_to_java_column(str), substr))


@since(3.0)
def overlay(src, replace, pos, len=-1):
    """
    Overlay the specified portion of `src` with `replace`,
    starting from byte position `pos` of `src` and proceeding for `len` bytes.

    :param src: Column or column name whose values are partially replaced.
    :param replace: Column or column name supplying the replacement.
    :param pos: start position of the replacement; an integer, a Column or a column name.
    :param len: number of bytes to replace; an integer, a Column or a column name.
        Defaults to -1, which is forwarded to the JVM ``overlay`` function as a literal.

    >>> df = spark.createDataFrame([("SPARK_SQL", "CORE")], ("x", "y"))
    >>> df.select(overlay("x", "y", 7).alias("overlayed")).show()
    +----------+
    | overlayed|
    +----------+
    |SPARK_CORE|
    +----------+
    """
    # Validate both arguments identically and up front, so that any unsupported
    # value (including None) is reported by argument name here rather than
    # surfacing later from the column conversion.
    for name, value in (("pos", pos), ("len", len)):
        if not isinstance(value, (int, str, Column)):
            raise TypeError(
                "{} should be an integer or a Column / column name, got {}".format(
                    name, type(value)))

    def _as_java_column(value):
        # Integers become literal columns; strings and Columns are resolved as columns.
        if isinstance(value, int):
            return _create_column_from_literal(value)
        return _to_java_column(value)

    pos = _as_java_column(pos)
    len = _as_java_column(len)

    sc = SparkContext._active_spark_context

    return Column(sc._jvm.functions.overlay(
        _to_java_column(src),
        _to_java_column(replace),
        pos,
        len
    ))


@since(1.5)
@ignore_unicode_prefix
def substring(str, pos, len):
Expand Down
27 changes: 27 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,33 @@ def test_input_file_name_udf(self):
file_name = df.collect()[0].file
self.assertTrue("python/test_support/hello/hello.txt" in file_name)

def test_overlay(self):
    """Check the string form of overlay() columns built from literals, names and Columns."""
    from pyspark.sql.functions import col, lit, overlay
    import re

    # Each case pairs an overlay() call with the expression text expected
    # inside the string representation of the resulting Column.
    cases = [
        (overlay(col("foo"), col("bar"), 1), "overlay(foo, bar, 1, -1)"),
        (overlay("x", "y", 3), "overlay(x, y, 3, -1)"),
        (overlay(col("x"), col("y"), 1, 3), "overlay(x, y, 1, 3)"),
        (overlay("x", "y", 2, 5), "overlay(x, y, 2, 5)"),
        (overlay("x", "y", lit(11)), "overlay(x, y, 11, -1)"),
        (overlay("x", "y", lit(2), lit(5)), "overlay(x, y, 2, 5)"),
    ]

    # Collect every overlay(...) fragment found in the column representations.
    actual = []
    for column, _ in cases:
        actual.extend(re.findall("(overlay\\(.*\\))", str(column)))

    expected = [text for _, text in cases]

    self.assertListEqual(actual, expected)


if __name__ == "__main__":
import unittest
Expand Down