# Spark Sedona Prefect workflow template


## 1. Why we cannot use a global SparkSession in Prefect

### 1.1 Process boundaries

In our case, we use a `process` work pool, so each task runs in a separate subprocess (multiprocessing).

Global variables in the workflow definition (main process) do not cross process boundaries, which means a Spark session created there is not visible inside the task worker processes (subprocesses).

### 1.2 SparkSession cannot be pickled

You might think the problem goes away if the main process passes the Spark session to the subprocess as an argument. It does not, because the session would have to be serialized first.

Prefect serializes parameters and state using Pydantic + cloudpickle.

A SparkSession contains:
- Java gateway references
- Socket connections
- JVM pointers
- Py4J handles

These cannot be serialized, so a task that receives a SparkSession as an argument will fail with a serialization error.
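The failure can be reproduced without Spark or Prefect. The sketch below is only an illustration: `FakeSession` is a hypothetical stand-in that holds an open socket (as a real session holds Py4J connections), and it uses the standard `pickle` module rather than Prefect's own serialization path.

```python
import pickle
import socket


class FakeSession:
    """Hypothetical stand-in for a SparkSession: it owns a live OS resource."""

    def __init__(self) -> None:
        # A real SparkSession talks to the JVM through Py4J sockets; an
        # unconnected socket is enough to reproduce the failure.
        self.gateway_socket = socket.socket()

    def close(self) -> None:
        self.gateway_socket.close()


def try_pickle(obj: object) -> str:
    """Return 'ok' if obj can be pickled, else the error type and message."""
    try:
        pickle.dumps(obj)
    except Exception as exc:  # typically TypeError or pickle.PicklingError
        return f"{type(exc).__name__}: {exc}"
    return "ok"


session = FakeSession()
try:
    result = try_pickle(session)
finally:
    session.close()

print(result)                      # reports a TypeError for the socket attribute
print(try_pickle({"app": "demo"}))  # plain data pickles fine
```

Plain values such as paths, app names, and config dictionaries serialize without trouble, which is why it is safer to pass those to a task and build the session inside it.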

### 1.3 Spark is not thread-safe

Two tasks using one Spark session at the same time can cause race conditions:
- JVM deadlocks
- Out-of-memory crashes
- Inconsistent job runs
- Spark UI showing mixed jobs from different tasks

### 1.4 Global SparkSession breaks worker restart & resiliency

If a worker crashes:

- Prefect restarts the process
- Any global SparkSession state is gone
- Tasks that depend on that global session cannot recover

Prefect's engine assumes tasks are stateless and reproducible. Global spark session state violates this model.

## Conclusion: we must create a Spark session in each task

## Best practices

To avoid duplicating the Spark session creation code, we use a factory function that builds a Spark session with a standard configuration.



In [2]:
import datetime
import tempfile
from pyspark.sql import SparkSession
from pathlib import Path
import time
import shutil

def build_spark_temp_dir() -> str:
    """
    Create a temporary directory for Spark temporary files. Giving each Spark session a unique temp folder avoids
    file-lock conflicts on Windows. The folder can be safely deleted after spark.stop().
    Example output: C:/Users/alice/AppData/Local/Temp/spark_temp/spark_20250329_124501_839201
    :return: the temporary directory path as a POSIX-style string
    """
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    base = Path(tempfile.gettempdir())
    spark_temp_dir = base / "spark_temp" / f"spark_{ts}"
    spark_temp_dir.mkdir(parents=True, exist_ok=True)
    return spark_temp_dir.as_posix()


In [3]:
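def build_spark_session(
    app_name: str,
    extra_jar_folder_path: str | None = None,
    driver_memory: str = "4g",
    local_dir_path: str | None = None,
) -> tuple[SparkSession, str]:
    """
    Build a SparkSession with a standardized configuration and optional extra JARs.

    :param app_name: name of the Spark application, shown in the Spark UI
    :param extra_jar_folder_path: folder whose .jar files are added to spark.jars; skipped if None
    :param driver_memory: driver memory, for example "4g"
    :param local_dir_path: directory used for spark.local.dir; a unique temp folder is created if None
    :return: the Spark session and the path of the local dir it uses
    """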
    # create a standard spark session builder
    builder = (
        SparkSession.builder
        .appName(app_name)
        .master("local[4]")
        .config("spark.driver.memory", driver_memory)
        # enable AQE
        .config("spark.sql.adaptive.enabled", "true")
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
        # give a partition size advice.
        .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")
        # set AQE partition range (Spark 3.x names of the post-shuffle partition settings)
        .config("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "100")
        .config("spark.sql.adaptive.coalescePartitions.minPartitionNum", "1")
        # increase worker timeout
        .config("spark.network.timeout", "800s")
        .config("spark.executor.heartbeatInterval", "90s")
        .config("spark.sql.sources.commitProtocolClass",
                "org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol")
        # JVM memory allocation
        .config("spark.driver.maxResultSize", "4g")  # Avoid OOM on collect()
        # Shuffle & partition tuning
        .config("spark.sql.files.maxPartitionBytes", "128m")  # Avoid large partitions in memory
        .config("spark.reducer.maxSizeInFlight", "48m")  # Limit shuffle buffer
        # Unified memory management
        .config("spark.memory.fraction", "0.7")  # Reduce pressure on execution memory
        .config("spark.memory.storageFraction", "0.3")  # Smaller cache area
        # Spill to disk early instead of crashing
        .config("spark.shuffle.spill", "true")
        .config("spark.shuffle.spill.compress", "true")
        .config("spark.shuffle.compress", "true")
        # optimize jvm GC
        .config("spark.driver.extraJavaOptions",
                "-XX:+UseG1GC -XX:InitiatingHeapOccupancyPercent=35 -XX:+HeapDumpOnOutOfMemoryError")
        # Use Kryo serializer
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        # Optional: buffer size for serialization
        .config("spark.kryoserializer.buffer", "64m")
        .config("spark.kryoserializer.buffer.max", "512m")
    )

    # Ensure the local dir exists.
    # If no local dir path is provided, generate a unique temp folder.
    if local_dir_path is None:
        local_dir_path = build_spark_temp_dir()
    else:
        local_dir = Path(local_dir_path)
        local_dir.mkdir(parents=True, exist_ok=True)
    builder = builder.config("spark.local.dir", local_dir_path)

    # Load extra JARs only when present
    if extra_jar_folder_path:
        jar_dir = Path(extra_jar_folder_path)
        jar_files = [
            str(p) for p in jar_dir.iterdir()
            if p.is_file() and p.suffix == ".jar"
        ]
        if jar_files:
            builder = builder.config("spark.jars", ",".join(jar_files))

    return builder.getOrCreate(), local_dir_path

In [14]:
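def safe_delete(target_path: str, retries: int = 5, delay: float = 0.5) -> bool:
    """
    Delete the Spark temp dir after the Spark session is stopped, retrying on PermissionError.
    :param target_path: path of the directory to delete
    :param retries: maximum number of deletion attempts
    :param delay: base wait in seconds; attempt n waits delay * n before the next try
    :return: True if the directory no longer exists, False if it still exists after all retries
    """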
    target = Path(target_path)

    for attempt in range(1, retries + 1):
        if not target.exists():
            return True     # Already deleted

        try:
            shutil.rmtree(target)
        except PermissionError:
            time.sleep(delay * attempt)   # linear backoff: the wait grows with each attempt

        # Double-check if deletion actually succeeded
        if not target.exists():
            return True

    return False  # Explicit failure after retries
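Since every task has to build, stop, and clean up its own session, the three steps can be tied together so that none is forgotten when a task fails. The following is a minimal sketch, not part of the template itself: `managed_session` is a hypothetical helper, and it takes the factory and cleanup functions as arguments so it works with `build_spark_session` and `safe_delete` or with test doubles.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


@contextmanager
def managed_session(
    factory: Callable[[], tuple[Any, str]],
    cleanup: Callable[[str], Any],
) -> Iterator[Any]:
    """Yield a session from factory(), then always stop it and clean its temp dir.

    factory must return (session, temp_dir_path); cleanup receives temp_dir_path.
    """
    session, temp_dir = factory()
    try:
        yield session
    finally:
        try:
            session.stop()      # release the session's resources first
        finally:
            cleanup(temp_dir)   # then remove the temp folder, even if stop() failed


# Intended use inside a task body (assumes the helpers defined in this notebook):
# with managed_session(lambda: build_spark_session("word-count"), safe_delete) as spark:
#     spark.read.text(src_file).show()
```

Because the session is created and destroyed inside the `with` block, nothing Spark-related has to be passed between tasks or kept in a global variable.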

In [5]:
# Here we use Sedona 1.8.0, built for Spark 3.5.* and Scala 2.12
project_root_dir = Path.cwd().parent
sedona_version = "sedona-35-212-180"
jar_folder_path = f"{project_root_dir}/jars/{sedona_version}"

In [6]:
spark, spark_temp_dir_path = build_spark_session("my-sedona-demo", jar_folder_path, driver_memory="6g")

In [7]:
from sedona.spark import SedonaContext

# create a sedona context
sedona = SedonaContext.create(spark)

In [8]:
data_dir = "../data"
src_file = f"{data_dir}/source/word_raw.txt"
out_file = f"{data_dir}/out/wc_out"

df = spark.read.text(src_file)

df.show()

+--------------------+
|               value|
+--------------------+
|This is a test da...|
|   data file updated|
|data is the new b...|
+--------------------+



In [9]:

from pyspark.sql.functions import explode, split, col

words = df.select(explode(split(col(df.columns[0]), "\\s+")).alias("word"))
counts = words.groupBy("word").count()


counts.show()

+-------+-----+
|   word|count|
+-------+-----+
|  gold.|    1|
|    new|    1|
|    for|    1|
|     is|    2|
|updated|    1|
|   data|    3|
| count.|    1|
|   file|    2|
|    the|    1|
|   word|    1|
|  black|    1|
|   This|    1|
|      a|    1|
|   test|    1|
+-------+-----+



In [10]:
counts.write.mode("overwrite").csv(out_file)

In [11]:
spark.stop()

In [17]:
safe_delete(spark_temp_dir_path)
safe_delete("C:/Users/pliu/AppData/Local/Temp/spark_temp/spark_20251204_161815_299461")

False

In [16]:
print(spark_temp_dir_path)

C:/Users/pliu/AppData/Local/Temp/spark_temp/spark_20251204_161815_299461
