diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3ca4edafa6873..d7ac6d89c045a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2068,7 +2068,11 @@ def slice(x, start, length):
     [Row(sliced=[2, 3]), Row(sliced=[5])]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.slice(_to_java_column(x), start, length))
+    return Column(sc._jvm.functions.slice(
+        _to_java_column(x),
+        start._jc if isinstance(start, Column) else start,
+        length._jc if isinstance(length, Column) else length
+    ))
 
 
 @since(2.4)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 7dcc19f3ba45d..02180daf081ec 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -292,6 +292,16 @@ def test_input_file_name_reset_for_rdd(self):
         for result in results:
             self.assertEqual(result[0], '')
 
+    def test_slice(self):
+        from pyspark.sql.functions import slice, lit
+
+        df = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
+
+        self.assertEqual(
+            df.select(slice(df.x, 2, 2).alias("sliced")).collect(),
+            df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect(),
+        )
+
     def test_array_repeat(self):
         from pyspark.sql.functions import array_repeat, lit
 