[SPARK-35382][PYTHON] Fix lambda variable name issues in nested DataFrame functions in Python APIs

### What changes were proposed in this pull request?

This PR fixes the same issue as #32424.

```py
from pyspark.sql.functions import flatten, struct, transform
df = spark.sql("SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters")
df.select(flatten(
    transform(
        "numbers",
        lambda number: transform(
            "letters",
            lambda letter: struct(number.alias("n"), letter.alias("l"))
        )
    )
).alias("zipped")).show(truncate=False)
```

**Before:**

```
+------------------------------------------------------------------------+
|zipped                                                                  |
+------------------------------------------------------------------------+
|[{a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}]|
+------------------------------------------------------------------------+
```

**After:**

```
+------------------------------------------------------------------------+
|zipped                                                                  |
+------------------------------------------------------------------------+
|[{1, a}, {1, b}, {1, c}, {2, a}, {2, b}, {2, c}, {3, a}, {3, b}, {3, c}]|
+------------------------------------------------------------------------+
```

### Why are the changes needed?

To produce the correct results. Previously, every lambda level reused the same fixed variable names (`x`, `y`, `z`), so an inner lambda's variable shadowed the outer one of the same name; with `freshVarName`, each nesting level gets a unique variable name and resolves correctly.
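For context, here is a minimal plain-Python sketch (not Spark internals; `make_var` is a hypothetical stand-in for `_unresolved_named_lambda_variable`) of why reusing a fixed argument name across nested lambdas causes shadowing, and how suffixing a fresh id, in the spirit of `freshVarName`, avoids it:

```python
import itertools

# Global counter used to mint unique suffixes, mimicking freshVarName.
_counter = itertools.count()

def make_var(name, fresh=False):
    """Build a named lambda variable; optionally make the name unique."""
    return f"{name}_{next(_counter)}" if fresh else name

# Without fresh names, the outer and inner lambda both bind "x",
# so the inner variable shadows the outer one:
outer, inner = make_var("x"), make_var("x")
assert outer == inner

# With fresh names, each nesting level gets a distinct variable,
# so references resolve to the intended level:
outer, inner = make_var("x", fresh=True), make_var("x", fresh=True)
assert outer != inner
```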

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the results to be correct as mentioned above.

### How was this patch tested?

Added a unit test; also verified manually.

Closes #32523 from ueshin/issues/SPARK-35382/nested_higher_order_functions.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 17b59a9)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ueshin authored and HyukjinKwon committed May 13, 2021
1 parent 82e461a commit 67e4c94
Showing 2 changed files with 26 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/pyspark/sql/functions.py

```diff
@@ -4153,7 +4153,10 @@ def _create_lambda(f):

     argnames = ["x", "y", "z"]
     args = [
-        _unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)]
+        _unresolved_named_lambda_variable(
+            expressions.UnresolvedNamedLambdaVariable.freshVarName(arg)
+        )
+        for arg in argnames[: len(parameters)]
     ]

     result = f(*args)
```
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests/test_functions.py

```diff
@@ -491,6 +491,28 @@ def test_higher_order_function_failures(self):
         with self.assertRaises(ValueError):
             transform(col("foo"), lambda x: 1)

+    def test_nested_higher_order_function(self):
+        # SPARK-35382: lambda vars must be resolved properly in nested higher order functions
+        from pyspark.sql.functions import flatten, struct, transform
+
+        df = self.spark.sql("SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters")
+
+        actual = df.select(flatten(
+            transform(
+                "numbers",
+                lambda number: transform(
+                    "letters",
+                    lambda letter: struct(number.alias("n"), letter.alias("l"))
+                )
+            )
+        )).first()[0]
+
+        expected = [(1, "a"), (1, "b"), (1, "c"),
+                    (2, "a"), (2, "b"), (2, "c"),
+                    (3, "a"), (3, "b"), (3, "c")]
+
+        self.assertEquals(actual, expected)
+
     def test_window_functions(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.partitionBy("value").orderBy("key")
```