In [0]:
# Enable Arrow-based columnar data transfer between Spark and pandas.
# The older key "spark.sql.execution.arrow.enabled" is deprecated since
# Spark 3.0; the PySpark-specific key below is its replacement.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

In [0]:
# Read the raw synthetic image dataset (parquet on S3) into a DataFrame.
input_data = (
    spark.read
    .format("parquet")
    .load("s3://air-example-data-2/10G-image-data-synthetic-raw-parquet")
)
# Reads are lazy, so push the data through the "noop" sink to force the
# read to actually execute.
(
    input_data.write
    .format("noop")
    .mode("overwrite")
    .save()
)

In [0]:
# Report how many partitions the input has next to the cluster's default
# parallelism, to check that parallelism exceeds the partition count.
num_data_partitions = input_data.rdd.getNumPartitions()
spark_max_parallelism = sc.defaultParallelism
print("# data partitions: ", num_data_partitions)
print("# Spark max parallelism: ", spark_max_parallelism)

# data partitions:  32
# Spark max parallelism:  64


In [0]:
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, FloatType

import pandas as pd
import time

@pandas_udf(ArrayType(FloatType()))
def dummy_preprocess(image: pd.Series) -> pd.Series:
    """Stand-in preprocessing step: pause for one second, then pass the input through.

    Args:
        image: pandas Series of values handed to the UDF.

    Returns:
        The same Series, unmodified.
    """
    # Simulate a fixed amount of work on every invocation.
    simulated_work_seconds = 1
    time.sleep(simulated_work_seconds)
    return image



In [0]:
# Apply the sleeping UDF (1 second per call) to the "image" column.
# Available parallelism exceeds the number of data partitions, so every
# partition should be able to run at the same time.
image_column = col("image")
dummy_preprocessed_data = input_data.select(dummy_preprocess(image_column))

In [0]:
# Force execution of preprocessing and measure how long it takes.
# perf_counter() is a monotonic, high-resolution clock, so the measured
# interval cannot be skewed by system clock adjustments (unlike time.time()).
start_time = time.perf_counter()
dummy_preprocessed_data.write.mode("overwrite").format("noop").save()
end_time = time.perf_counter()
print(f"Preprocessing took: {end_time-start_time} seconds")

Preprocessing took: 84.83807134628296 seconds


In [0]:
# Print the Python profiler statistics collected for the UDF executions.
# NOTE(review): presumably requires "spark.python.profile" to be enabled when
# the SparkContext is created — confirm against the cluster configuration.
sc.show_profiles()

Profile of UDF<id=56>
         48 function calls in 16.015 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       16   16.014    1.001   16.014    1.001 {built-in method time.sleep}
       16    0.000    0.000   16.015    1.001 <command-566047737056994>:7(dummy_preprocess)
       16    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


