Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 3 additions & 32 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,14 @@ class TestUDTF:
def eval(self, a: int):
...

# TODO(SPARK-43967): Support Python UDTFs with empty return values
with self.assertRaisesRegex(PythonException, "TypeError"):
TestUDTF(lit(1)).collect()
self.assertEqual(TestUDTF(lit(1)).collect(), [])

@udtf(returnType="a: int")
class TestUDTF:
def eval(self, a: int):
return

with self.assertRaisesRegex(PythonException, "TypeError"):
TestUDTF(lit(1)).collect()
self.assertEqual(TestUDTF(lit(1)).collect(), [])

def test_udtf_with_conditional_return(self):
class TestUDTF:
Expand All @@ -195,9 +192,7 @@ class TestUDTF:
def eval(self, a: int):
yield

# TODO(SPARK-43967): Support Python UDTFs with empty return values
with self.assertRaisesRegex(Py4JJavaError, "java.lang.NullPointerException"):
TestUDTF(lit(1)).collect()
assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=None)])

def test_udtf_with_none_output(self):
@udtf(returnType="a: int")
Expand Down Expand Up @@ -807,21 +802,6 @@ def eval(self, a: int):
func = udtf(TestUDTF, returnType="a: int")
self.assertEqual(func(lit(1)).collect(), [Row(a=1)])

def test_udtf_eval_with_no_return(self):
@udtf(returnType="a: int")
class TestUDTF:
def eval(self, a: int):
...

self.assertEqual(TestUDTF(lit(1)).collect(), [])

@udtf(returnType="a: int")
class TestUDTF:
def eval(self, a: int):
return

self.assertEqual(TestUDTF(lit(1)).collect(), [])

def test_udtf_terminate_with_wrong_num_output(self):
# The error message for arrow-optimized UDTF is different from regular UDTF.
err_msg = "The number of columns in the result does not match the specified schema."
Expand All @@ -848,15 +828,6 @@ def terminate(self):
with self.assertRaisesRegex(PythonException, err_msg):
TestUDTF(lit(1)).show()

def test_udtf_with_empty_yield(self):
@udtf(returnType="a: int")
class TestUDTF:
def eval(self, a: int):
yield

# Arrow-optimized UDTF can support this.
self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=None)])

def test_udtf_with_wrong_num_output(self):
# The error message for arrow-optimized UDTF is different.
err_msg = "The number of columns in the result does not match the specified schema."
Expand Down
20 changes: 17 additions & 3 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,21 @@ def mapper(_, it):
def wrap_udtf(f, return_type):
assert return_type.needConversion()
toInternal = return_type.toInternal
return lambda *a: map(toInternal, f(*a))

# Evaluate the function and return a tuple back to the executor.
def evaluate(*a) -> tuple:
res = f(*a)
if res is None:
# If the function returns None or does not have an explicit return statement,
# an empty tuple is returned to the executor.
# This is because directly constructing tuple(None) raises a TypeError.
return tuple()
else:
# If the function returns a result, we map it to the internal representation and
# return the results as a tuple.
return tuple(map(toInternal, res))

return evaluate

eval = wrap_udtf(getattr(udtf, "eval"), return_type)

Expand All @@ -592,11 +606,11 @@ def wrap_udtf(f, return_type):
def mapper(_, it):
try:
for a in it:
yield tuple(eval(*[a[o] for o in arg_offsets]))
yield eval(*[a[o] for o in arg_offsets])
finally:
if terminate is not None:
try:
yield tuple(terminate())
yield terminate()
except BaseException as e:
raise PySparkRuntimeError(
error_class="UDTF_EXEC_ERROR",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
}

val joined = new JoinedRow
val nullRow = new GenericInternalRow(udtf.elementSchema.length)
val resultProj = UnsafeProjection.create(output, output)

outputRowIterator.flatMap { outputRows =>
Expand All @@ -118,7 +119,15 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
// from the UDTF are from the `terminate()` call. We leave the left side as the last
// element of its child output to keep it consistent with the Generate implementation
// and Hive UDTFs.
outputRows.map(r => resultProj(joined.withRight(r)))
outputRows.map { r =>
// When the UDTF's result is None, such as `def eval(): yield`,
// we join it with a null row to avoid NullPointerException.
if (r == null) {
resultProj(joined.withRight(nullRow))
} else {
resultProj(joined.withRight(r))
}
}
}
}
}
Expand Down