diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 7cd74ecd5..74d70d6ee 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -277,6 +277,25 @@ def _iter( return list(_iter(exprs)) +def _to_raw_literal_expr(value: Expr | Any) -> expr_internal.Expr: + """Convert an expression or Python literal to its raw variant. + + Args: + value: Candidate expression or Python literal value. + + Returns: + The internal :class:`~datafusion._internal.expr.Expr` representation. + + Examples: + >>> expr = Expr(_to_raw_literal_expr(1)) + >>> isinstance(expr, Expr) + True + """ + if isinstance(value, Expr): + return value.expr + return Expr.literal(value).expr + + def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 841cd9c0b..7794de11d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -29,6 +29,7 @@ Expr, SortExpr, SortKey, + _to_raw_literal_expr, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, sort_or_default, @@ -1440,7 +1441,9 @@ def radians(arg: Expr) -> Expr: return Expr(f.radians(arg.expr)) -def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_like( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Find if any regular expression (regex) matches exist. Tests a string using a regular expression returning true if at least one match, @@ -1451,8 +1454,7 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> df = ctx.from_pydict({"a": ["hello123"]}) >>> result = df.select( ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("m") + ... dfn.col("a"), "\\d+").alias("m") ... 
) >>> result.collect_column("m")[0].as_py() True @@ -1461,19 +1463,20 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("HELLO"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "HELLO", + ... flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() True """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_like(string.expr, regex.expr, flags)) + flags = _to_raw_literal_expr(flags) if flags is not None else None + return Expr(f.regexp_like(string.expr, _to_raw_literal_expr(regex), flags)) -def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_match( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Perform regular expression (regex) matching. Returns an array with each element containing the leftmost-first match of the @@ -1484,8 +1487,7 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(\\d+)") - ... ).alias("m") + ... dfn.col("a"), "(\\d+)").alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['42'] @@ -1494,20 +1496,22 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(HELLO)"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "(HELLO)", + ... flags="i", ... ).alias("m") ... 
) >>> result.collect_column("m")[0].as_py() ['hello'] """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_match(string.expr, regex.expr, flags)) + flags = _to_raw_literal_expr(flags) if flags is not None else None + return Expr(f.regexp_match(string.expr, _to_raw_literal_expr(regex), flags)) def regexp_replace( - string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + replacement: Expr | str, + flags: Expr | str | None = None, ) -> Expr: r"""Replaces substring(s) matching a PCRE-like regular expression. @@ -1522,8 +1526,8 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["hello 42"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("XX") + ... dfn.col("a"), "\\d+", + ... "XX" ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() @@ -1534,20 +1538,24 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["a1 b2 c3"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("X"), flags=dfn.lit("g"), + ... dfn.col("a"), "\\d+", + ... "X", flags="g", ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() 'aX bX cX' """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + flags = _to_raw_literal_expr(flags) if flags is not None else None + pattern = _to_raw_literal_expr(pattern) + replacement = _to_raw_literal_expr(replacement) + return Expr(f.regexp_replace(string.expr, pattern, replacement, flags)) def regexp_count( - string: Expr, pattern: Expr, start: Expr | None = None, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, ) -> Expr: """Returns the number of matches in a string. @@ -1559,8 +1567,7 @@ def regexp_count( >>> df = ctx.from_pydict({"a": ["abcabc"]}) >>> result = df.select( ... 
dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("abc") - ... ).alias("c")) + ... dfn.col("a"), "abc").alias("c")) >>> result.collect_column("c")[0].as_py() 2 @@ -1569,25 +1576,33 @@ def regexp_count( >>> result = df.select( ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("ABC"), - ... start=dfn.lit(4), flags=dfn.lit("i"), + ... dfn.col("a"), "ABC", + ... start=4, flags="i", ... ).alias("c")) >>> result.collect_column("c")[0].as_py() 1 """ - if flags is not None: - flags = flags.expr - start = start.expr if start is not None else start - return Expr(f.regexp_count(string.expr, pattern.expr, start, flags)) + pattern = _to_raw_literal_expr(pattern) + flags = _to_raw_literal_expr(flags) if flags is not None else None + start = _to_raw_literal_expr(start) if start is not None else None + + # If Python callers pass only flags, supply the default start=1, + # because DataFusion accepts: + # regexp_count(string, pattern, start) + # regexp_count(string, pattern, start, flags) + if start is None and flags is not None: + start = _to_raw_literal_expr(1) + + return Expr(f.regexp_count(string.expr, pattern, start, flags)) def regexp_instr( values: Expr, - regex: Expr, - start: Expr | None = None, - n: Expr | None = None, - flags: Expr | None = None, - sub_expr: Expr | None = None, + regex: Expr | str, + start: Expr | int | None = None, + n: Expr | int | None = None, + flags: Expr | str | None = None, + sub_expr: Expr | int | None = None, ) -> Expr: r"""Returns the position of a regular expression match in a string. @@ -1604,8 +1619,7 @@ def regexp_instr( >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("pos") + ... dfn.col("a"), "\\d+").alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 7 @@ -1616,9 +1630,9 @@ def regexp_instr( >>> df = ctx.from_pydict({"a": ["abc ABC abc"]}) >>> result = df.select( ... dfn.functions.regexp_instr( - ... 
dfn.col("a"), dfn.lit("abc"), - ... start=dfn.lit(2), n=dfn.lit(1), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "abc", + ... start=2, n=1, + ... flags="i", ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() @@ -1628,22 +1642,23 @@ def regexp_instr( >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("(abc)"), - ... sub_expr=dfn.lit(1), + ... dfn.col("a"), "(abc)", + ... sub_expr=1, ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 1 """ - start = start.expr if start is not None else None - n = n.expr if n is not None else None - flags = flags.expr if flags is not None else None - sub_expr = sub_expr.expr if sub_expr is not None else None + regex = _to_raw_literal_expr(regex) + start = _to_raw_literal_expr(start) if start is not None else None + n = _to_raw_literal_expr(n) if n is not None else None + flags = _to_raw_literal_expr(flags) if flags is not None else None + sub_expr = _to_raw_literal_expr(sub_expr) if sub_expr is not None else None return Expr( f.regexp_instr( values.expr, - regex.expr, + regex, start, n, flags, diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 11e94af1c..30c92768d 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -932,6 +932,30 @@ def test_map_functions(func, expected): f.regexp_count(column("a"), literal("(ell|orl)")), pa.array([1, 1, 0], type=pa.int64()), ), + ( + f.regexp_like(column("a"), "(ell|orl)"), + pa.array([True, True, False]), + ), + ( + f.regexp_match(column("a"), "(ell|orl)"), + pa.array([["ell"], ["orl"], None], type=pa.list_(pa.string_view())), + ), + ( + f.regexp_replace(column("a"), "(ell|orl)", "-"), + pa.array(["H-o", "W-d", "!"], type=pa.string_view()), + ), + ( + f.regexp_count(column("a"), "(ell|orl)", start=1), + pa.array([1, 1, 0], type=pa.int64()), + ), + ( + f.regexp_count(column("a"), "(ELL|ORL)", flags="i"), + pa.array([1, 1, 0], type=pa.int64()), + ), + ( + 
f.regexp_instr(column("a"), "([lr])", n=2), + pa.array([4, 4, 0], type=pa.int64()), + ), ( f.regexp_instr(column("a"), literal("(ell|orl)")), pa.array([2, 2, 0], type=pa.int64()),