In [3]:
import os
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-21-openjdk-amd64'
# Make sure checkpoint directory exists
os.makedirs("checkpoints", exist_ok=True)
from pyspark.sql import SparkSession
# Then run
spark = SparkSession.builder \
    .appName("PureSparkShortestPath") \
    .config("spark.driver.memory", "8g") \
    .config("spark.executor.memory", "4g") \
    .getOrCreate()

spark.sparkContext.setCheckpointDir("checkpoints")

25/11/16 16:52:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [4]:
"""
Pure Spark SQL implementation for all-pairs shortest path calculation.
Best for: Large groups (500+ rows per partition) or when pandas is unavailable.
Uses distributed Spark operations (joins, aggregations) throughout.
"""

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F


def compute_shortest_paths_for_group(df_group: DataFrame, max_iterations: int = None) -> DataFrame:
    """
    Computes all-pairs shortest paths for a single group using pure Spark SQL.
    Uses iterative path extension with self-joins.
    
    Input schema: (src, dst, cost, next_node)
    Output schema: (src, dst, cost, next_node)
    
    Args:
        df_group: DataFrame containing edges/shortcuts for a single group
        max_iterations: Maximum iterations (default: row count of group)
    
    Returns:
        DataFrame with shortest paths (all pairs reachable within the group)
    """
    if max_iterations is None:
        max_iterations = df_group.count()
    
    # Rename to standard internal names
    current_paths = df_group.select(
        F.col("src").alias("path_start"),
        F.col("dst").alias("path_end"),
        F.col("cost"),
        F.col("next_node")
    ).cache()
    
    iteration = 0
    
    while iteration < max_iterations:
        iteration += 1
        
        # Extend paths by joining: if we have A->B and B->C, create A->C
        new_paths = current_paths.alias("L").join(
            current_paths.alias("R"),
            F.col("L.path_end") == F.col("R.path_start"),
            "inner"
        ).select(
            F.col("L.path_start").alias("path_start"),
            F.col("R.path_end").alias("path_end"),
            (F.col("L.cost") + F.col("R.cost")).alias("cost"),
            F.col("L.next_node").alias("next_node")  # Keep first hop
        ).filter(F.col("path_start") != F.col("path_end"))  # No self-loops
        
        # Combine with existing paths
        all_paths = current_paths.unionByName(new_paths)
        
        # Keep only minimum cost path for each (start, end) pair
        min_costs = all_paths.groupBy("path_start", "path_end").agg(
            F.min("cost").alias("min_cost")
        )
        
        next_paths = all_paths.alias("T1").join(
            min_costs.alias("T2"),
            (F.col("T1.path_start") == F.col("T2.path_start")) &
            (F.col("T1.path_end") == F.col("T2.path_end")) &
            (F.col("T1.cost") == F.col("T2.min_cost")),
            "inner"
        ).select(
            F.col("T1.path_start"),
            F.col("T1.path_end"),
            F.col("T2.min_cost").alias("cost"),
            F.col("T1.next_node")
        ).dropDuplicates(["path_start", "path_end"]).cache()
        
        # Check convergence
        if next_paths.count() == current_paths.count():
            diff = current_paths.subtract(next_paths).count()
            if diff == 0:
                current_paths.unpersist()
                # Rename back to original schema
                return next_paths.select(
                    F.col("path_start").alias("src"),
                    F.col("path_end").alias("dst"),
                    F.col("cost"),
                    F.col("next_node")
                )
        
        current_paths.unpersist()
        current_paths = next_paths
    
    # Return with original column names
    return current_paths.select(
        F.col("path_start").alias("src"),
        F.col("path_end").alias("dst"),
        F.col("cost"),
        F.col("next_node")
    )


def compute_shortest_paths_per_partition(
    df_shortcuts: DataFrame,
    partition_columns: list,
    max_iterations_per_group: int = None,
    checkpoint_interval: int = 5
) -> DataFrame:
    """
    Computes all-pairs shortest paths for each partition using pure Spark SQL.
    Processes each partition separately to avoid cross-partition joins.
    
    Args:
        df_shortcuts: DataFrame with schema (src, dst, cost, next_node, partition_col1, ...)
        partition_columns: List of column names to group by
        max_iterations_per_group: Maximum iterations per group
        checkpoint_interval: Checkpoint every N partitions to prevent lineage issues
    
    Returns:
        DataFrame with computed shortest paths, preserving partition columns
    """
    
    # Add a unique partition ID for easier tracking
    df_with_pid = df_shortcuts.withColumn(
        "_partition_id",
        F.concat_ws("_", *[F.col(c).cast("string") for c in partition_columns])
    )
    
    # Get all unique partition IDs
    partition_ids = [row["_partition_id"] for row in 
                     df_with_pid.select("_partition_id").distinct().collect()]
    
    print(f"Processing {len(partition_ids)} partitions...")
    
    results = []
    
    for idx, pid in enumerate(partition_ids):
        print(f"Processing partition {idx + 1}/{len(partition_ids)}: {pid}")
        
        # Extract data for this partition only
        partition_df = df_with_pid.filter(F.col("_partition_id") == pid)
        
        # Get partition column values (for adding back later)
        partition_values = partition_df.select(partition_columns).first().asDict()
        
        # Compute shortest paths for this partition
        shortest_paths = compute_shortest_paths_for_group(
            partition_df.select("src", "dst", "cost", "next_node"),
            max_iterations=max_iterations_per_group
        )
        
        # Add back partition columns
        for col, val in partition_values.items():
            shortest_paths = shortest_paths.withColumn(col, F.lit(val))
        
        results.append(shortest_paths)
        
        # Checkpoint periodically to prevent long lineage
        if checkpoint_interval and (idx + 1) % checkpoint_interval == 0:
            print(f"Checkpointing at partition {idx + 1}...")
    
    # Union all results
    if not results:
        return df_shortcuts.limit(0)  # Empty DataFrame with same schema
    
    final_result = results[0]
    for result_df in results[1:]:
        final_result = final_result.unionByName(result_df)
    
    return final_result


def merge_shortcuts_to_main_table(
    main_df: DataFrame,
    new_shortcuts: DataFrame,
    partition_columns: list
) -> DataFrame:
    """
    Merges newly computed shortcuts back into the main table.
    
    Strategy:
    1. Remove old shortcuts for the processed partitions
    2. Add new shortcuts
    
    Args:
        main_df: Main shortcut table
        new_shortcuts: Newly computed shortcuts
        partition_columns: Columns used for partitioning
    
    Returns:
        Updated DataFrame with new shortcuts
    """
    # Get the partition values that were recomputed
    recomputed_partitions = new_shortcuts.select(partition_columns).distinct()
    
    # Remove old shortcuts for these partitions using left_anti join
    remaining_df = main_df.alias("main").join(
        recomputed_partitions.alias("recomp"),
        [F.col(f"main.{col}") == F.col(f"recomp.{col}") for col in partition_columns],
        "left_anti"
    )
    
    # Combine with new shortcuts
    updated_df = remaining_df.unionByName(new_shortcuts)
    
    return updated_df


def example_pipeline(spark: SparkSession):
    """
    Example usage in an iterative pipeline.
    """
    
    # Example: Create initial shortcuts table
    data = [
        ("A", "B", 1.0, "B", "region1"),
        ("B", "C", 2.0, "C", "region1"),
        ("C", "D", 1.0, "D", "region1"),
        ("A", "C", 5.0, "C", "region1"),  # Suboptimal shortcut
        ("E", "F", 1.0, "F", "region2"),
        ("F", "G", 2.0, "G", "region2"),
    ]
    
    shortcuts = spark.createDataFrame(
        data, 
        ["src", "dst", "cost", "next_node", "region_id"]
    )
    
    print("Initial shortcuts:")
    shortcuts.show()
    
    # Iterative pipeline
    for iteration in range(3):
        print(f"\n{'='*60}")
        print(f"ITERATION {iteration + 1}")
        print(f"{'='*60}")
        
        # Step 1: Filter (example: keep only cost < 10)
        filtered = shortcuts.filter(F.col("cost") < 10)
        
        # Step 2: Compute shortest paths per partition
        new_shortcuts = compute_shortest_paths_per_partition(
            filtered,
            partition_columns=["region_id"],
            max_iterations_per_group=100
        )
        
        print(f"\nNew shortcuts computed:")
        new_shortcuts.show()
        
        # Step 3: Merge back to main table
        shortcuts = merge_shortcuts_to_main_table(
            shortcuts,
            new_shortcuts,
            partition_columns=["region_id"]
        )
        
        print(f"\nUpdated shortcuts table:")
        shortcuts.show()
    
    return shortcuts


if __name__ == "__main__":
    spark = SparkSession.builder \
        .appName("PureSparkShortestPath") \
        .config("spark.driver.memory", "4g") \
        .getOrCreate()
    
    spark.sparkContext.setCheckpointDir("checkpoints")
    
    result = example_pipeline(spark)
    
    print("\n" + "="*60)
    print("FINAL RESULT")
    print("="*60)
    result.show()
    
    spark.stop()

25/11/16 16:52:59 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


Initial shortcuts:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  C| 5.0|        C|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


ITERATION 1
Processing 2 partitions...
Processing partition 1/2: region1
Processing partition 2/2: region2

New shortcuts computed:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  B|  C| 2.0|        C|  region1|
|  A|  C| 3.0|        B|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  D| 4.0|        B|  region1|
|  B|  D| 3.0|        C|  region1|
|  A|  B| 1.0|        B|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
|  E|  G| 3.0|        F|  region2|
+---+---+----+---------+---------+


Updated shortcuts table:
+---+---+----+-

                                                                                

Processing partition 2/2: region2


                                                                                


New shortcuts computed:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  B|  C| 2.0|        C|  region1|
|  A|  C| 3.0|        B|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  D| 4.0|        B|  region1|
|  B|  D| 3.0|        C|  region1|
|  A|  B| 1.0|        B|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
|  E|  G| 3.0|        F|  region2|
+---+---+----+---------+---------+


Updated shortcuts table:


Py4JJavaError: An error occurred while calling o1967.showString.
: java.lang.OutOfMemoryError: Java heap space
	at java.base/java.util.Arrays.copyOf(Arrays.java:3482)
	at java.base/java.util.ArrayList.grow(ArrayList.java:237)
	at java.base/java.util.ArrayList.grow(ArrayList.java:244)
	at java.base/java.util.ArrayList.add(ArrayList.java:483)
	at java.base/java.util.ArrayList.add(ArrayList.java:496)
	at org.apache.spark.sql.catalyst.util.StringConcat.append(StringUtils.scala:47)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$explainString$1(QueryExecution.scala:312)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$explainString$1$adapted(QueryExecution.scala:312)
	at org.apache.spark.sql.execution.QueryExecution$$Lambda/0x0000746af0bf72c0.apply(Unknown Source)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$generateTreeString$2(TreeNode.scala:1032)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda/0x0000746af0cf73f8.apply$mcVI$sp(Unknown Source)
	at scala.collection.immutable.Range.foreach$mVc$sp(Range.scala:192)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1030)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$generateTreeString$4(TreeNode.scala:1071)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$generateTreeString$4$adapted(TreeNode.scala:1069)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda/0x0000746af0cf7028.apply(Unknown Source)
	at scala.collection.immutable.Vector.foreach(Vector.scala:2125)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1069)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)
	at org.apache.spark.sql.catalyst.trees.TreeNode.generateTreeString(TreeNode.scala:1078)


In [5]:
"""
Pure Spark SQL implementation for all-pairs shortest path calculation.
Best for: Large groups (500+ rows per partition) or when pandas is unavailable.
Uses distributed Spark operations (joins, aggregations) throughout.
"""

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F


def compute_shortest_paths_for_group(df_group: DataFrame, max_iterations: int = None, 
                                      checkpoint_interval: int = 3) -> DataFrame:
    """
    Computes all-pairs shortest paths for a single group using pure Spark SQL.
    Uses iterative path extension with self-joins.
    
    Input schema: (src, dst, cost, next_node)
    Output schema: (src, dst, cost, next_node)
    
    Args:
        df_group: DataFrame containing edges/shortcuts for a single group
        max_iterations: Maximum iterations (default: row count of group)
        checkpoint_interval: Checkpoint every N iterations to prevent lineage explosion
    
    Returns:
        DataFrame with shortest paths (all pairs reachable within the group)
    """
    if max_iterations is None:
        max_iterations = df_group.count()
    
    # Rename to standard internal names
    current_paths = df_group.select(
        F.col("src").alias("path_start"),
        F.col("dst").alias("path_end"),
        F.col("cost"),
        F.col("next_node")
    )
    
    # Materialize initial data
    current_paths = current_paths.persist()
    current_paths.count()  # Force materialization
    
    iteration = 0
    
    while iteration < max_iterations:
        iteration += 1
        print(f"  Iteration {iteration}...")
        
        # Extend paths by joining: if we have A->B and B->C, create A->C
        new_paths = current_paths.alias("L").join(
            current_paths.alias("R"),
            F.col("L.path_end") == F.col("R.path_start"),
            "inner"
        ).select(
            F.col("L.path_start").alias("path_start"),
            F.col("R.path_end").alias("path_end"),
            (F.col("L.cost") + F.col("R.cost")).alias("cost"),
            F.col("L.next_node").alias("next_node")  # Keep first hop
        ).filter(F.col("path_start") != F.col("path_end"))  # No self-loops
        
        # Combine with existing paths
        all_paths = current_paths.unionByName(new_paths)
        
        # Keep only minimum cost path for each (start, end) pair
        min_costs = all_paths.groupBy("path_start", "path_end").agg(
            F.min("cost").alias("min_cost")
        )
        
        next_paths = all_paths.alias("T1").join(
            min_costs.alias("T2"),
            (F.col("T1.path_start") == F.col("T2.path_start")) &
            (F.col("T1.path_end") == F.col("T2.path_end")) &
            (F.col("T1.cost") == F.col("T2.min_cost")),
            "inner"
        ).select(
            F.col("T1.path_start"),
            F.col("T1.path_end"),
            F.col("T2.min_cost").alias("cost"),
            F.col("T1.next_node")
        ).dropDuplicates(["path_start", "path_end"])
        
        # Checkpoint periodically to break lineage
        if iteration % checkpoint_interval == 0:
            next_paths = next_paths.checkpoint()
            print(f"  Checkpointed at iteration {iteration}")
        
        # Materialize and cache
        next_paths = next_paths.persist()
        next_count = next_paths.count()  # Force materialization
        
        # Check convergence
        current_count = current_paths.count()
        
        if next_count == current_count:
            diff = current_paths.subtract(next_paths).count()
            if diff == 0:
                print(f"  Converged at iteration {iteration}")
                current_paths.unpersist()
                # Rename back to original schema
                return next_paths.select(
                    F.col("path_start").alias("src"),
                    F.col("path_end").alias("dst"),
                    F.col("cost"),
                    F.col("next_node")
                )
        
        # Unpersist old data and move to next iteration
        current_paths.unpersist()
        current_paths = next_paths
    
    # Return with original column names
    return current_paths.select(
        F.col("path_start").alias("src"),
        F.col("path_end").alias("dst"),
        F.col("cost"),
        F.col("next_node")
    )


def compute_shortest_paths_per_partition(
    df_shortcuts: DataFrame,
    partition_columns: list,
    max_iterations_per_group: int = None
) -> DataFrame:
    """
    Computes all-pairs shortest paths for each partition using pure Spark SQL.
    Processes each partition separately to avoid cross-partition joins.
    
    Args:
        df_shortcuts: DataFrame with schema (src, dst, cost, next_node, partition_col1, ...)
        partition_columns: List of column names to group by
        max_iterations_per_group: Maximum iterations per group
    
    Returns:
        DataFrame with computed shortest paths, preserving partition columns
    """
    
    # Add a unique partition ID for easier tracking
    df_with_pid = df_shortcuts.withColumn(
        "_partition_id",
        F.concat_ws("_", *[F.col(c).cast("string") for c in partition_columns])
    )
    
    # Get all unique partition IDs
    partition_ids = [row["_partition_id"] for row in 
                     df_with_pid.select("_partition_id").distinct().collect()]
    
    print(f"Processing {len(partition_ids)} partitions...")
    
    results = []
    
    for idx, pid in enumerate(partition_ids):
        print(f"\nProcessing partition {idx + 1}/{len(partition_ids)}: {pid}")
        
        # Extract data for this partition only
        partition_df = df_with_pid.filter(F.col("_partition_id") == pid)
        
        # Get partition column values (for adding back later)
        partition_values = partition_df.select(partition_columns).first().asDict()
        
        # Compute shortest paths for this partition
        shortest_paths = compute_shortest_paths_for_group(
            partition_df.select("src", "dst", "cost", "next_node"),
            max_iterations=max_iterations_per_group,
            checkpoint_interval=3
        )
        
        # Add back partition columns
        for col, val in partition_values.items():
            shortest_paths = shortest_paths.withColumn(col, F.lit(val))
        
        # Materialize this partition's results
        shortest_paths = shortest_paths.persist()
        shortest_paths.count()  # Force materialization
        
        results.append(shortest_paths)
    
    # Union all results
    if not results:
        return df_shortcuts.limit(0)  # Empty DataFrame with same schema
    
    print(f"\nCombining results from {len(results)} partitions...")
    final_result = results[0]
    for result_df in results[1:]:
        final_result = final_result.unionByName(result_df)
    
    # Final materialization
    final_result = final_result.persist()
    final_result.count()
    
    # Clean up intermediate results
    for result_df in results:
        result_df.unpersist()
    
    return final_result


def merge_shortcuts_to_main_table(
    main_df: DataFrame,
    new_shortcuts: DataFrame,
    partition_columns: list
) -> DataFrame:
    """
    Merges newly computed shortcuts back into the main table.
    
    Strategy:
    1. Remove old shortcuts for the processed partitions
    2. Add new shortcuts
    
    Args:
        main_df: Main shortcut table
        new_shortcuts: Newly computed shortcuts
        partition_columns: Columns used for partitioning
    
    Returns:
        Updated DataFrame with new shortcuts
    """
    # Get the partition values that were recomputed
    recomputed_partitions = new_shortcuts.select(partition_columns).distinct()
    
    # Remove old shortcuts for these partitions using left_anti join
    remaining_df = main_df.alias("main").join(
        recomputed_partitions.alias("recomp"),
        [F.col(f"main.{col}") == F.col(f"recomp.{col}") for col in partition_columns],
        "left_anti"
    )
    
    # Combine with new shortcuts
    updated_df = remaining_df.unionByName(new_shortcuts)
    
    return updated_df


def example_pipeline(spark: SparkSession):
    """
    Example usage in an iterative pipeline.
    """
    
    # Example: Create initial shortcuts table
    data = [
        ("A", "B", 1.0, "B", "region1"),
        ("B", "C", 2.0, "C", "region1"),
        ("C", "D", 1.0, "D", "region1"),
        ("A", "C", 5.0, "C", "region1"),  # Suboptimal shortcut
        ("E", "F", 1.0, "F", "region2"),
        ("F", "G", 2.0, "G", "region2"),
    ]
    
    shortcuts = spark.createDataFrame(
        data, 
        ["src", "dst", "cost", "next_node", "region_id"]
    )
    
    print("Initial shortcuts:")
    shortcuts.show()
    
    # Iterative pipeline
    for iteration in range(3):
        print(f"\n{'='*60}")
        print(f"ITERATION {iteration + 1}")
        print(f"{'='*60}")
        
        # Step 1: Filter (example: keep only cost < 10)
        filtered = shortcuts.filter(F.col("cost") < 10)
        
        # Step 2: Compute shortest paths per partition
        new_shortcuts = compute_shortest_paths_per_partition(
            filtered,
            partition_columns=["region_id"],
            max_iterations_per_group=100
        )
        
        print(f"\nNew shortcuts computed:")
        new_shortcuts.show()
        
        # Step 3: Merge back to main table
        shortcuts = merge_shortcuts_to_main_table(
            shortcuts,
            new_shortcuts,
            partition_columns=["region_id"]
        )
        
        print(f"\nUpdated shortcuts table:")
        shortcuts.show()
    
    return shortcuts


if __name__ == "__main__":
    # Increase driver memory to handle larger plans
    spark = SparkSession.builder \
        .appName("PureSparkShortestPath") \
        .config("spark.driver.memory", "8g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
        .getOrCreate()
    
    spark.sparkContext.setCheckpointDir("checkpoints")
    
    print("="*60)
    print("PURE SPARK SQL SHORTEST PATH CALCULATOR")
    print("="*60)
    
    result = example_pipeline(spark)
    
    print("\n" + "="*60)
    print("FINAL RESULT")
    print("="*60)
    result.show()
    
    spark.stop()

25/11/16 16:54:50 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


PURE SPARK SQL SHORTEST PATH CALCULATOR
Initial shortcuts:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  C| 5.0|        C|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


ITERATION 1
Processing 2 partitions...

Processing partition 1/2: region1
  Iteration 1...


                                                                                

  Iteration 2...


                                                                                

  Iteration 3...


                                                                                

  Checkpointed at iteration 3


                                                                                

  Converged at iteration 3

Processing partition 2/2: region2
  Iteration 1...


                                                                                

  Iteration 2...


                                                                                

  Converged at iteration 2


                                                                                


Combining results from 2 partitions...


                                                                                


New shortcuts computed:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  A|  C| 3.0|        B|  region1|
|  A|  D| 4.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  B|  D| 3.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
|  E|  G| 3.0|        F|  region2|
+---+---+----+---------+---------+


Updated shortcuts table:


                                                                                

+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  A|  C| 3.0|        B|  region1|
|  A|  D| 4.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  B|  D| 3.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
|  E|  G| 3.0|        F|  region2|
+---+---+----+---------+---------+


ITERATION 2


                                                                                

Processing 2 partitions...

Processing partition 1/2: region1


                                                                                

  Iteration 1...


                                                                                

  Converged at iteration 1


                                                                                


Processing partition 2/2: region2


                                                                                

  Iteration 1...


                                                                                

  Converged at iteration 1


Exception in thread "RemoteBlock-temp-file-clean-thread" java.lang.OutOfMemoryError: Java heap space
	at java.base/java.lang.invoke.DirectMethodHandle.allocateInstance(DirectMethodHandle.java:501)
	at java.base/java.lang.invoke.DirectMethodHandle$Holder.newInvokeSpecial(DirectMethodHandle$Holder)
	at java.base/java.lang.invoke.Invokers$Holder.linkToTargetMethod(Invokers$Holder)
	at org.apache.spark.storage.BlockManager$RemoteBlockDownloadFileManager.org$apache$spark$storage$BlockManager$RemoteBlockDownloadFileManager$$keepCleaning(BlockManager.scala:2274)
	at org.apache.spark.storage.BlockManager$RemoteBlockDownloadFileManager$$anon$2.run(BlockManager.scala:2240)


Py4JJavaError: An error occurred while calling o2941.count.
: java.lang.OutOfMemoryError: Java heap space
	at java.base/java.lang.AbstractStringBuilder.<init>(AbstractStringBuilder.java:101)
	at java.base/java.lang.StringBuilder.<init>(StringBuilder.java:119)
	at org.apache.spark.sql.catalyst.util.StringConcat.toString(StringUtils.scala:63)
	at org.apache.spark.sql.catalyst.util.StringUtils$PlanStringConcat.toString(StringUtils.scala:149)
	at org.apache.spark.sql.execution.QueryExecution.explainString(QueryExecution.scala:314)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.onUpdatePlan(AdaptiveSparkPlanExec.scala:858)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$withFinalPlanUpdate$2(AdaptiveSparkPlanExec.scala:292)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec$$Lambda/0x0000746af0f9c870.apply$mcVJ$sp(Unknown Source)
	at scala.runtime.java8.JFunction1$mcVJ$sp.apply(JFunction1$mcVJ$sp.scala:18)
	at scala.Option.foreach(Option.scala:437)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$withFinalPlanUpdate$1(AdaptiveSparkPlanExec.scala:292)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec$$Lambda/0x0000746af0f8cd00.apply$mcV$sp(Unknown Source)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:804)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.withFinalPlanUpdate(AdaptiveSparkPlanExec.scala:279)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.executeCollect(AdaptiveSparkPlanExec.scala:402)
	at org.apache.spark.sql.classic.Dataset.$anonfun$count$1(Dataset.scala:1500)
	at org.apache.spark.sql.classic.Dataset.$anonfun$count$1$adapted(Dataset.scala:1499)
	at org.apache.spark.sql.classic.Dataset$$Lambda/0x0000746af11043c8.apply(Unknown Source)
	at org.apache.spark.sql.classic.Dataset.$anonfun$withAction$2(Dataset.scala:2234)
	at org.apache.spark.sql.classic.Dataset$$Lambda/0x0000746af0d05d40.apply(Unknown Source)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:654)
	at org.apache.spark.sql.classic.Dataset.$anonfun$withAction$1(Dataset.scala:2232)
	at org.apache.spark.sql.classic.Dataset$$Lambda/0x0000746af0bebb58.apply(Unknown Source)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$8(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$$$Lambda/0x0000746af0bf5748.apply(Unknown Source)
	at org.apache.spark.sql.execution.SQLExecution$.withSessionTagsApplied(SQLExecution.scala:272)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$7(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$$$Lambda/0x0000746af0bf4b30.apply(Unknown Source)
	at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
	at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:112)
	at org.apache.spark.sql.artifact.ArtifactManager$$Lambda/0x0000746af0bf4df8.apply(Unknown Source)


In [6]:
"""
Fixed Pure Spark SQL implementation for all-pairs shortest path calculation.
Key fixes:
1. Aggressive checkpointing every iteration to break lineage
2. Write intermediate results to disk to force materialization
3. No .cache() - use checkpoint instead
4. Simplified convergence check to avoid subtract() operations
"""

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
import tempfile
import shutil


def compute_shortest_paths_for_group(
    df_group: DataFrame, 
    max_iterations: int = None,
    temp_dir: str = None
) -> DataFrame:
    """
    Computes all-pairs shortest paths using pure Spark SQL with aggressive lineage breaking.
    
    The key insight: We must materialize to disk (not just cache) to truly break lineage.
    
    Args:
        df_group: DataFrame with schema (src, dst, cost, next_node)
        max_iterations: Maximum iterations (default: row count)
        temp_dir: Temporary directory for checkpointing
    
    Returns:
        DataFrame with shortest paths
    """
    if max_iterations is None:
        max_iterations = df_group.count()
    
    # Create temp directory if not provided
    if temp_dir is None:
        temp_dir = tempfile.mkdtemp(prefix="spark_shortest_path_")
    
    # Initialize: rename columns
    current_paths = df_group.select(
        F.col("src").alias("path_start"),
        F.col("dst").alias("path_end"),
        F.col("cost"),
        F.col("next_node")
    )
    
    # Checkpoint immediately to break initial lineage
    current_paths = current_paths.checkpoint(eager=True)
    prev_count = current_paths.count()
    
    print(f"  Starting with {prev_count} paths")
    
    for iteration in range(1, max_iterations + 1):
        print(f"  Iteration {iteration}...", end=" ")
        
        # Step 1: Find new paths via self-join
        new_paths = current_paths.alias("L").join(
            current_paths.alias("R"),
            F.col("L.path_end") == F.col("R.path_start"),
            "inner"
        ).select(
            F.col("L.path_start").alias("path_start"),
            F.col("R.path_end").alias("path_end"),
            (F.col("L.cost") + F.col("R.cost")).alias("cost"),
            F.col("L.next_node").alias("next_node")
        ).filter(F.col("path_start") != F.col("path_end"))
        
        # Step 2: Union with existing paths
        all_paths = current_paths.unionByName(new_paths)
        
        # Step 3: Keep minimum cost per (start, end) pair
        # Use window function to avoid additional join
        from pyspark.sql.window import Window
        window = Window.partitionBy("path_start", "path_end").orderBy("cost")
        
        next_paths = all_paths.withColumn("rank", F.row_number().over(window)) \
                              .filter(F.col("rank") == 1) \
                              .drop("rank")
        
        # CRITICAL: Checkpoint EVERY iteration to break lineage
        # Use eager=True to force immediate materialization
        next_paths = next_paths.checkpoint(eager=True)
        
        # Count new paths
        next_count = next_paths.count()
        print(f"{next_count} paths")
        
        # Simple convergence check: if count didn't change, we're done
        if next_count == prev_count:
            print(f"  Converged! No new paths added.")
            # Rename back to original schema
            return next_paths.select(
                F.col("path_start").alias("src"),
                F.col("path_end").alias("dst"),
                F.col("cost"),
                F.col("next_node")
            )
        
        prev_count = next_count
        current_paths = next_paths
    
    print(f"  Reached max iterations ({max_iterations})")
    
    # Return with original column names
    return current_paths.select(
        F.col("path_start").alias("src"),
        F.col("path_end").alias("dst"),
        F.col("cost"),
        F.col("next_node")
    )


def compute_shortest_paths_per_partition(
    df_shortcuts: DataFrame,
    partition_columns: list,
    max_iterations_per_group: int = None
) -> DataFrame:
    """
    Computes shortest paths for each partition separately.
    Processes one partition at a time to avoid memory issues.
    
    Args:
        df_shortcuts: DataFrame with schema (src, dst, cost, next_node, partition_cols...)
        partition_columns: Columns to partition by
        max_iterations_per_group: Max iterations per group
    
    Returns:
        DataFrame with all shortest paths
    """
    
    # Get unique partition values
    partitions = df_shortcuts.select(partition_columns).distinct().collect()
    
    print(f"\nProcessing {len(partitions)} partitions...")
    
    results = []
    temp_dir = tempfile.mkdtemp(prefix="spark_partitions_")
    
    try:
        for idx, partition_row in enumerate(partitions):
            print(f"\n[Partition {idx + 1}/{len(partitions)}]")
            
            # Build filter condition for this partition
            partition_dict = partition_row.asDict()
            filter_condition = None
            for col, val in partition_dict.items():
                condition = (F.col(col) == F.lit(val))
                filter_condition = condition if filter_condition is None else (filter_condition & condition)
            
            # Extract this partition's data
            partition_df = df_shortcuts.filter(filter_condition)
            partition_size = partition_df.count()
            print(f"  Partition has {partition_size} edges")
            
            # Compute shortest paths for this partition
            shortest_paths = compute_shortest_paths_for_group(
                partition_df.select("src", "dst", "cost", "next_node"),
                max_iterations=max_iterations_per_group,
                temp_dir=temp_dir
            )
            
            # Add partition columns back
            for col, val in partition_dict.items():
                shortest_paths = shortest_paths.withColumn(col, F.lit(val))
            
            # Write this partition's result to temp location to break lineage
            partition_result_path = f"{temp_dir}/partition_{idx}"
            shortest_paths.write.mode("overwrite").parquet(partition_result_path)
            
            # Read back (fully breaks lineage)
            partition_result = df_shortcuts.sparkSession.read.parquet(partition_result_path)
            results.append(partition_result)
            
            print(f"  Completed: {partition_result.count()} shortest paths computed")
        
        # Union all partition results
        if not results:
            return df_shortcuts.limit(0)
        
        print(f"\nCombining {len(results)} partition results...")
        final_result = results[0]
        for result_df in results[1:]:
            final_result = final_result.unionByName(result_df)
        
        return final_result
        
    finally:
        # Cleanup temp directory
        try:
            shutil.rmtree(temp_dir)
        except:
            pass


def merge_shortcuts_to_main_table(
    main_df: DataFrame,
    new_shortcuts: DataFrame,
    partition_columns: list
) -> DataFrame:
    """
    Merges new shortcuts back into main table.
    Removes old shortcuts for recomputed partitions and adds new ones.
    """
    recomputed_partitions = new_shortcuts.select(partition_columns).distinct()
    
    # Keep shortcuts from partitions that weren't recomputed
    remaining_df = main_df.alias("main").join(
        recomputed_partitions.alias("recomp"),
        [F.col(f"main.{col}") == F.col(f"recomp.{col}") for col in partition_columns],
        "left_anti"
    )
    
    # Add new shortcuts
    updated_df = remaining_df.unionByName(new_shortcuts)
    
    return updated_df


def example_pipeline(spark: SparkSession):
    """Example usage demonstrating the fixed implementation."""
    
    # Create test data
    data = [
        ("A", "B", 1.0, "B", "region1"),
        ("B", "C", 2.0, "C", "region1"),
        ("C", "D", 1.0, "D", "region1"),
        ("A", "C", 5.0, "C", "region1"),  # Suboptimal path
        ("E", "F", 1.0, "F", "region2"),
        ("F", "G", 2.0, "G", "region2"),
    ]
    
    shortcuts = spark.createDataFrame(
        data, 
        ["src", "dst", "cost", "next_node", "region_id"]
    )
    
    print("Initial shortcuts:")
    shortcuts.show()
    
    # Iterative pipeline
    for iteration in range(3):
        print(f"\n{'='*70}")
        print(f"PIPELINE ITERATION {iteration + 1}")
        print(f"{'='*70}")
        
        # Filter
        filtered = shortcuts.filter(F.col("cost") < 10)
        print(f"\nFiltered: {filtered.count()} shortcuts")
        
        # Compute shortest paths per partition
        new_shortcuts = compute_shortest_paths_per_partition(
            filtered,
            partition_columns=["region_id"],
            max_iterations_per_group=20  # Reduced for safety
        )
        
        print(f"\nNew shortcuts computed: {new_shortcuts.count()}")
        new_shortcuts.show()
        
        # Merge back
        shortcuts = merge_shortcuts_to_main_table(
            shortcuts,
            new_shortcuts,
            partition_columns=["region_id"]
        )
        
        print(f"\nUpdated table: {shortcuts.count()} total shortcuts")
    
    return shortcuts


if __name__ == "__main__":
    # Configuration - still keep memory reasonable
    spark = SparkSession.builder \
        .appName("FixedPureSparkShortestPath") \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.shuffle.partitions", "10") \
        .getOrCreate()
    
    # Set checkpoint directory
    import os
    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    spark.sparkContext.setCheckpointDir(checkpoint_dir)
    
    print("="*70)
    print("FIXED PURE SPARK SQL SHORTEST PATH CALCULATOR")
    print("="*70)
    
    result = example_pipeline(spark)
    
    print("\n" + "="*70)
    print("FINAL RESULT")
    print("="*70)
    result.show()
    
    spark.stop()

25/11/16 16:59:17 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
                                                                                

FIXED PURE SPARK SQL SHORTEST PATH CALCULATOR
Initial shortcuts:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  C| 5.0|        C|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


PIPELINE ITERATION 1

Filtered: 6 shortcuts

Processing 2 partitions...

[Partition 1/2]
  Partition has 4 edges
  Starting with 4 paths
  Iteration 1... 6 paths
  Iteration 2... 6 paths
  Converged! No new paths added.
  Completed: 6 shortest paths computed

[Partition 2/2]
  Partition has 2 edges
  Starting with 2 paths
  Iteration 1... 3 paths
  Iteration 2... 3 paths
  Converged! No new paths added.
  Completed: 3 shortest paths computed

Combining 2 partition results...


25/11/16 16:59:21 ERROR Executor: Exception in task 0.0 in stage 1391.0 (TID 54360)
org.apache.spark.SparkException: [FAILED_READ_FILE.FILE_NOT_EXIST] Encountered error while reading file file:///tmp/spark_partitions_5jxo1hf4/partition_0/part-00000-be0fbed6-869a-4dd9-bd93-016983a3c2ee-c000.snappy.parquet. File does not exist. It is possible the underlying files have been updated.
You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved. SQLSTATE: KD001
	at org.apache.spark.sql.errors.QueryExecutionErrors$.fileNotExistError(QueryExecutionErrors.scala:831)
	at org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2$.attachFilePath(FileDataSourceV2.scala:140)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:142)
	at org.apache.spark.sql.execution.FileSourceScanExec$$anon$1.hasNext(DataSourceScanExec.scala:695)
	at org.apache.spark.sql.catalyst.expr

Py4JJavaError: An error occurred while calling o3331.count.
: org.apache.spark.SparkException: [FAILED_READ_FILE.FILE_NOT_EXIST] Encountered error while reading file file:///tmp/spark_partitions_5jxo1hf4/partition_0/part-00000-be0fbed6-869a-4dd9-bd93-016983a3c2ee-c000.snappy.parquet. File does not exist. It is possible the underlying files have been updated.
You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved. SQLSTATE: KD001
	at org.apache.spark.sql.errors.QueryExecutionErrors$.fileNotExistError(QueryExecutionErrors.scala:831)
	at org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2$.attachFilePath(FileDataSourceV2.scala:140)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:142)
	at org.apache.spark.sql.execution.FileSourceScanExec$$anon$1.hasNext(DataSourceScanExec.scala:695)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.columnartorow_nextBatch_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.hashAgg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:143)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:57)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:111)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: java.io.FileNotFoundException: File file:/tmp/spark_partitions_5jxo1hf4/partition_0/part-00000-be0fbed6-869a-4dd9-bd93-016983a3c2ee-c000.snappy.parquet does not exist
	at org.apache.hadoop.fs.RawLocalFileSystem.deprecatedGetFileStatus(RawLocalFileSystem.java:917)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileLinkStatusInternal(RawLocalFileSystem.java:1238)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileStatus(RawLocalFileSystem.java:907)
	at org.apache.hadoop.fs.FilterFileSystem.getFileStatus(FilterFileSystem.java:462)
	at org.apache.parquet.hadoop.util.HadoopInputFile.fromPath(HadoopInputFile.java:38)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:71)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:66)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.$anonfun$buildReaderWithPartitionValues$2(ParquetFileFormat.scala:214)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.org$apache$spark$sql$execution$datasources$FileScanRDD$$anon$$readCurrentFile(FileScanRDD.scala:230)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:289)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext0(FileScanRDD.scala:131)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:140)
	... 21 more


In [7]:
"""
Fixed Pure Spark SQL implementation for all-pairs shortest path calculation.
Key fixes:
1. Aggressive checkpointing every iteration to break lineage
2. Write intermediate results to disk to force materialization
3. No .cache() - use checkpoint instead
4. Simplified convergence check to avoid subtract() operations
"""

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
import tempfile
import shutil


def compute_shortest_paths_for_group(
    df_group: DataFrame, 
    max_iterations: int = None,
    temp_dir: str = None
) -> DataFrame:
    """
    Computes all-pairs shortest paths using pure Spark SQL with aggressive lineage breaking.
    
    The key insight: We must materialize to disk (not just cache) to truly break lineage.
    
    Args:
        df_group: DataFrame with schema (src, dst, cost, next_node)
        max_iterations: Maximum iterations (default: row count)
        temp_dir: Temporary directory for checkpointing
    
    Returns:
        DataFrame with shortest paths
    """
    if max_iterations is None:
        max_iterations = df_group.count()
    
    # Create temp directory if not provided
    if temp_dir is None:
        temp_dir = tempfile.mkdtemp(prefix="spark_shortest_path_")
    
    # Initialize: rename columns
    current_paths = df_group.select(
        F.col("src").alias("path_start"),
        F.col("dst").alias("path_end"),
        F.col("cost"),
        F.col("next_node")
    )
    
    # Checkpoint immediately to break initial lineage
    current_paths = current_paths.checkpoint(eager=True)
    prev_count = current_paths.count()
    
    print(f"  Starting with {prev_count} paths")
    
    for iteration in range(1, max_iterations + 1):
        print(f"  Iteration {iteration}...", end=" ")
        
        # Step 1: Find new paths via self-join
        new_paths = current_paths.alias("L").join(
            current_paths.alias("R"),
            F.col("L.path_end") == F.col("R.path_start"),
            "inner"
        ).select(
            F.col("L.path_start").alias("path_start"),
            F.col("R.path_end").alias("path_end"),
            (F.col("L.cost") + F.col("R.cost")).alias("cost"),
            F.col("L.next_node").alias("next_node")
        ).filter(F.col("path_start") != F.col("path_end"))
        
        # Step 2: Union with existing paths
        all_paths = current_paths.unionByName(new_paths)
        
        # Step 3: Keep minimum cost per (start, end) pair
        # Use window function to avoid additional join
        from pyspark.sql.window import Window
        window = Window.partitionBy("path_start", "path_end").orderBy("cost")
        
        next_paths = all_paths.withColumn("rank", F.row_number().over(window)) \
                              .filter(F.col("rank") == 1) \
                              .drop("rank")
        
        # CRITICAL: Checkpoint EVERY iteration to break lineage
        # Use eager=True to force immediate materialization
        next_paths = next_paths.checkpoint(eager=True)
        
        # Count new paths
        next_count = next_paths.count()
        print(f"{next_count} paths")
        
        # Simple convergence check: if count didn't change, we're done
        if next_count == prev_count:
            print(f"  Converged! No new paths added.")
            # Rename back to original schema
            return next_paths.select(
                F.col("path_start").alias("src"),
                F.col("path_end").alias("dst"),
                F.col("cost"),
                F.col("next_node")
            )
        
        prev_count = next_count
        current_paths = next_paths
    
    print(f"  Reached max iterations ({max_iterations})")
    
    # Return with original column names
    return current_paths.select(
        F.col("path_start").alias("src"),
        F.col("path_end").alias("dst"),
        F.col("cost"),
        F.col("next_node")
    )


def compute_shortest_paths_per_partition(
    df_shortcuts: DataFrame,
    partition_columns: list,
    max_iterations_per_group: int = None
) -> DataFrame:
    """
    Computes shortest paths for each partition separately.
    Processes one partition at a time to avoid memory issues.
    
    Args:
        df_shortcuts: DataFrame with schema (src, dst, cost, next_node, partition_cols...)
        partition_columns: Columns to partition by
        max_iterations_per_group: Max iterations per group
    
    Returns:
        DataFrame with all shortest paths
    """
    
    # Get unique partition values
    partitions = df_shortcuts.select(partition_columns).distinct().collect()
    
    print(f"\nProcessing {len(partitions)} partitions...")
    
    results = []
    temp_dir = tempfile.mkdtemp(prefix="spark_partitions_")
    
    for idx, partition_row in enumerate(partitions):
        print(f"\n[Partition {idx + 1}/{len(partitions)}]")
        
        # Build filter condition for this partition
        partition_dict = partition_row.asDict()
        filter_condition = None
        for col, val in partition_dict.items():
            condition = (F.col(col) == F.lit(val))
            filter_condition = condition if filter_condition is None else (filter_condition & condition)
        
        # Extract this partition's data
        partition_df = df_shortcuts.filter(filter_condition)
        partition_size = partition_df.count()
        print(f"  Partition has {partition_size} edges")
        
        # Compute shortest paths for this partition
        shortest_paths = compute_shortest_paths_for_group(
            partition_df.select("src", "dst", "cost", "next_node"),
            max_iterations=max_iterations_per_group,
            temp_dir=temp_dir
        )
        
        # Add partition columns back
        for col, val in partition_dict.items():
            shortest_paths = shortest_paths.withColumn(col, F.lit(val))
        
        # Write this partition's result to temp location to break lineage
        partition_result_path = f"{temp_dir}/partition_{idx}"
        shortest_paths.write.mode("overwrite").parquet(partition_result_path)
        
        # Read back (fully breaks lineage)
        partition_result = df_shortcuts.sparkSession.read.parquet(partition_result_path)
        results.append(partition_result)
        
        print(f"  Completed: {partition_result.count()} shortest paths computed")
    
    # Union all partition results
    if not results:
        return df_shortcuts.limit(0)
    
    print(f"\nCombining {len(results)} partition results...")
    final_result = results[0]
    for result_df in results[1:]:
        final_result = final_result.unionByName(result_df)
    
    # CRITICAL: Materialize the union before cleanup
    # Write to a final location and read back
    final_output_path = f"{temp_dir}/final_result"
    final_result.write.mode("overwrite").parquet(final_output_path)
    final_result = df_shortcuts.sparkSession.read.parquet(final_output_path)
    
    # Now safe to cleanup temp directory
    try:
        shutil.rmtree(temp_dir)
        print(f"Cleaned up temp directory: {temp_dir}")
    except Exception as e:
        print(f"Warning: Could not clean up temp directory {temp_dir}: {e}")
    
    return final_result


def merge_shortcuts_to_main_table(
    main_df: DataFrame,
    new_shortcuts: DataFrame,
    partition_columns: list
) -> DataFrame:
    """
    Merges new shortcuts back into main table.
    Removes old shortcuts for recomputed partitions and adds new ones.
    """
    recomputed_partitions = new_shortcuts.select(partition_columns).distinct()
    
    # Keep shortcuts from partitions that weren't recomputed
    remaining_df = main_df.alias("main").join(
        recomputed_partitions.alias("recomp"),
        [F.col(f"main.{col}") == F.col(f"recomp.{col}") for col in partition_columns],
        "left_anti"
    )
    
    # Add new shortcuts
    updated_df = remaining_df.unionByName(new_shortcuts)
    
    return updated_df


def example_pipeline(spark: SparkSession):
    """Example usage demonstrating the fixed implementation."""
    
    # Create test data
    data = [
        ("A", "B", 1.0, "B", "region1"),
        ("B", "C", 2.0, "C", "region1"),
        ("C", "D", 1.0, "D", "region1"),
        ("A", "C", 5.0, "C", "region1"),  # Suboptimal path
        ("E", "F", 1.0, "F", "region2"),
        ("F", "G", 2.0, "G", "region2"),
    ]
    
    shortcuts = spark.createDataFrame(
        data, 
        ["src", "dst", "cost", "next_node", "region_id"]
    )
    
    print("Initial shortcuts:")
    shortcuts.show()
    
    # Iterative pipeline
    for iteration in range(3):
        print(f"\n{'='*70}")
        print(f"PIPELINE ITERATION {iteration + 1}")
        print(f"{'='*70}")
        
        # Filter
        filtered = shortcuts.filter(F.col("cost") < 10)
        print(f"\nFiltered: {filtered.count()} shortcuts")
        
        # Compute shortest paths per partition
        new_shortcuts = compute_shortest_paths_per_partition(
            filtered,
            partition_columns=["region_id"],
            max_iterations_per_group=20  # Reduced for safety
        )
        
        print(f"\nNew shortcuts computed: {new_shortcuts.count()}")
        new_shortcuts.show()
        
        # Merge back
        shortcuts = merge_shortcuts_to_main_table(
            shortcuts,
            new_shortcuts,
            partition_columns=["region_id"]
        )
        
        print(f"\nUpdated table: {shortcuts.count()} total shortcuts")
    
    return shortcuts


if __name__ == "__main__":
    # Configuration - still keep memory reasonable
    spark = SparkSession.builder \
        .appName("FixedPureSparkShortestPath") \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.shuffle.partitions", "10") \
        .getOrCreate()
    
    # Set checkpoint directory
    import os
    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    spark.sparkContext.setCheckpointDir(checkpoint_dir)
    
    print("="*70)
    print("FIXED PURE SPARK SQL SHORTEST PATH CALCULATOR")
    print("="*70)
    
    result = example_pipeline(spark)
    
    print("\n" + "="*70)
    print("FINAL RESULT")
    print("="*70)
    result.show()
    
    spark.stop()

FIXED PURE SPARK SQL SHORTEST PATH CALCULATOR
Initial shortcuts:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  C| 5.0|        C|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


PIPELINE ITERATION 1


                                                                                


Filtered: 6 shortcuts

Processing 2 partitions...

[Partition 1/2]
  Partition has 4 edges
  Starting with 4 paths
  Iteration 1... 6 paths
  Iteration 2... 6 paths
  Converged! No new paths added.
  Completed: 6 shortest paths computed

[Partition 2/2]
  Partition has 2 edges


                                                                                

  Starting with 2 paths
  Iteration 1... 3 paths
  Iteration 2... 3 paths
  Converged! No new paths added.
  Completed: 3 shortest paths computed

Combining 2 partition results...
Cleaned up temp directory: /tmp/spark_partitions_3rnxn6u2


25/11/16 17:01:10 ERROR Executor: Exception in task 1.0 in stage 1475.0 (TID 54706)
org.apache.spark.SparkException: [FAILED_READ_FILE.FILE_NOT_EXIST] Encountered error while reading file file:///tmp/spark_partitions_3rnxn6u2/final_result/part-00001-79a4a7ad-091e-48cf-a971-bd814a29ac3e-c000.snappy.parquet. File does not exist. It is possible the underlying files have been updated.
You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved. SQLSTATE: KD001
	at org.apache.spark.sql.errors.QueryExecutionErrors$.fileNotExistError(QueryExecutionErrors.scala:831)
	at org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2$.attachFilePath(FileDataSourceV2.scala:140)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:142)
	at org.apache.spark.sql.execution.FileSourceScanExec$$anon$1.hasNext(DataSourceScanExec.scala:695)
	at org.apache.spark.sql.catalyst.exp

Py4JJavaError: An error occurred while calling o3726.count.
: org.apache.spark.SparkException: [FAILED_READ_FILE.FILE_NOT_EXIST] Encountered error while reading file file:///tmp/spark_partitions_3rnxn6u2/final_result/part-00000-79a4a7ad-091e-48cf-a971-bd814a29ac3e-c000.snappy.parquet. File does not exist. It is possible the underlying files have been updated.
You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved. SQLSTATE: KD001
	at org.apache.spark.sql.errors.QueryExecutionErrors$.fileNotExistError(QueryExecutionErrors.scala:831)
	at org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2$.attachFilePath(FileDataSourceV2.scala:140)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:142)
	at org.apache.spark.sql.execution.FileSourceScanExec$$anon$1.hasNext(DataSourceScanExec.scala:695)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.columnartorow_nextBatch_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.hashAgg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:143)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:57)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:111)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: java.io.FileNotFoundException: File file:/tmp/spark_partitions_3rnxn6u2/final_result/part-00000-79a4a7ad-091e-48cf-a971-bd814a29ac3e-c000.snappy.parquet does not exist
	at org.apache.hadoop.fs.RawLocalFileSystem.deprecatedGetFileStatus(RawLocalFileSystem.java:917)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileLinkStatusInternal(RawLocalFileSystem.java:1238)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileStatus(RawLocalFileSystem.java:907)
	at org.apache.hadoop.fs.FilterFileSystem.getFileStatus(FilterFileSystem.java:462)
	at org.apache.parquet.hadoop.util.HadoopInputFile.fromPath(HadoopInputFile.java:38)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:71)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:66)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.$anonfun$buildReaderWithPartitionValues$2(ParquetFileFormat.scala:214)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.org$apache$spark$sql$execution$datasources$FileScanRDD$$anon$$readCurrentFile(FileScanRDD.scala:230)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:289)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext0(FileScanRDD.scala:131)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:140)
	... 21 more


In [10]:
"""
Fixed Pure Spark SQL implementation for all-pairs shortest path calculation.
Key fixes:
1. Aggressive checkpointing with eager=True to break lineage
2. Use Spark's checkpoint mechanism instead of manual file operations
3. Simplified convergence check
"""

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F


def compute_shortest_paths_for_group(
    df_group: DataFrame, 
    max_iterations: int = None
) -> DataFrame:
    """
    Computes all-pairs shortest paths using pure Spark SQL with aggressive lineage breaking.
    
    Uses eager checkpointing every iteration to prevent execution plan explosion.
    
    Args:
        df_group: DataFrame with schema (src, dst, cost, next_node)
        max_iterations: Maximum iterations (default: row count)
    
    Returns:
        DataFrame with shortest paths
    """
    if max_iterations is None:
        max_iterations = df_group.count()
    
    # Initialize: rename columns
    current_paths = df_group.select(
        F.col("src").alias("path_start"),
        F.col("dst").alias("path_end"),
        F.col("cost"),
        F.col("next_node")
    )
    
    # Checkpoint immediately to break initial lineage
    current_paths = current_paths.checkpoint(eager=True)
    prev_count = current_paths.count()
    
    print(f"  Starting with {prev_count} paths")
    
    for iteration in range(1, max_iterations + 1):
        print(f"  Iteration {iteration}...", end=" ")
        
        # Step 1: Find new paths via self-join
        new_paths = current_paths.alias("L").join(
            current_paths.alias("R"),
            F.col("L.path_end") == F.col("R.path_start"),
            "inner"
        ).select(
            F.col("L.path_start").alias("path_start"),
            F.col("R.path_end").alias("path_end"),
            (F.col("L.cost") + F.col("R.cost")).alias("cost"),
            F.col("L.next_node").alias("next_node")
        ).filter(F.col("path_start") != F.col("path_end"))
        
        # Step 2: Union with existing paths
        all_paths = current_paths.unionByName(new_paths)
        
        # Step 3: Keep minimum cost per (start, end) pair
        # Use window function to avoid additional join
        from pyspark.sql.window import Window
        window = Window.partitionBy("path_start", "path_end").orderBy("cost")
        
        next_paths = all_paths.withColumn("rank", F.row_number().over(window)) \
                              .filter(F.col("rank") == 1) \
                              .drop("rank")
        
        # CRITICAL: Checkpoint EVERY iteration to break lineage
        # Use eager=True to force immediate materialization
        next_paths = next_paths.checkpoint(eager=True)
        
        # Count new paths
        next_count = next_paths.count()
        print(f"{next_count} paths")
        
        # Simple convergence check: if count didn't change, we're done
        if next_count == prev_count:
            print(f"  Converged! No new paths added.")
            # Rename back to original schema
            return next_paths.select(
                F.col("path_start").alias("src"),
                F.col("path_end").alias("dst"),
                F.col("cost"),
                F.col("next_node")
            )
        
        prev_count = next_count
        current_paths = next_paths
    
    print(f"  Reached max iterations ({max_iterations})")
    
    # Return with original column names
    return current_paths.select(
        F.col("path_start").alias("src"),
        F.col("path_end").alias("dst"),
        F.col("cost"),
        F.col("next_node")
    )


def compute_shortest_paths_per_partition(
    df_shortcuts: DataFrame,
    partition_columns: list,
    max_iterations_per_group: int = None,
    intermediate_path: str = None
) -> DataFrame:
    """
    Computes shortest paths for each partition separately.
    Processes one partition at a time to avoid memory issues.
    
    Args:
        df_shortcuts: DataFrame with schema (src, dst, cost, next_node, partition_cols...)
        partition_columns: Columns to partition by
        max_iterations_per_group: Max iterations per group
        intermediate_path: Path to store intermediate results (will be cleaned up after)
    
    Returns:
        DataFrame with all shortest paths
    """
    
    # Get unique partition values
    partitions = df_shortcuts.select(partition_columns).distinct().collect()
    
    print(f"\nProcessing {len(partitions)} partitions...")
    
    # Use provided path or create one
    if intermediate_path is None:
        intermediate_path = f"temp_shortest_paths_{id(df_shortcuts)}"
    
    results = []
    
    for idx, partition_row in enumerate(partitions):
        print(f"\n[Partition {idx + 1}/{len(partitions)}]")
        
        # Build filter condition for this partition
        partition_dict = partition_row.asDict()
        filter_condition = None
        for col, val in partition_dict.items():
            condition = (F.col(col) == F.lit(val))
            filter_condition = condition if filter_condition is None else (filter_condition & condition)
        
        # Extract this partition's data
        partition_df = df_shortcuts.filter(filter_condition)
        partition_size = partition_df.count()
        print(f"  Partition has {partition_size} edges")
        
        # Compute shortest paths for this partition
        shortest_paths = compute_shortest_paths_for_group(
            partition_df.select("src", "dst", "cost", "next_node"),
            max_iterations=max_iterations_per_group
        )
        
        # Add partition columns back
        for col, val in partition_dict.items():
            shortest_paths = shortest_paths.withColumn(col, F.lit(val))
        
        # Checkpoint this partition's result to break lineage completely
        shortest_paths = shortest_paths.checkpoint(eager=True)
        
        results.append(shortest_paths)
        
        print(f"  Completed: {shortest_paths.count()} shortest paths computed")
    
    # Union all partition results
    if not results:
        return df_shortcuts.limit(0)
    
    print(f"\nCombining {len(results)} partition results...")
    final_result = results[0]
    for result_df in results[1:]:
        final_result = final_result.unionByName(result_df)
    
    # Checkpoint the final union to break lineage
    final_result = final_result.checkpoint(eager=True)
    
    return final_result


def merge_shortcuts_to_main_table(
    main_df: DataFrame,
    new_shortcuts: DataFrame,
    partition_columns: list
) -> DataFrame:
    """
    Merges new shortcuts back into main table.
    Removes old shortcuts for recomputed partitions and adds new ones.
    """
    recomputed_partitions = new_shortcuts.select(partition_columns).distinct()
    
    # Keep shortcuts from partitions that weren't recomputed
    remaining_df = main_df.alias("main").join(
        recomputed_partitions.alias("recomp"),
        [F.col(f"main.{col}") == F.col(f"recomp.{col}") for col in partition_columns],
        "left_anti"
    )
    
    # Add new shortcuts
    updated_df = remaining_df.unionByName(new_shortcuts)
    
    return updated_df


def example_pipeline(spark: SparkSession):
    """Example usage demonstrating the fixed implementation."""
    
    # Create test data
    data = [
        ("A", "B", 1.0, "B", "region1"),
        ("B", "C", 2.0, "C", "region1"),
        ("C", "D", 1.0, "D", "region1"),
        ("A", "C", 5.0, "C", "region1"),  # Suboptimal path
        ("E", "F", 1.0, "F", "region2"),
        ("F", "G", 2.0, "G", "region2"),
    ]
    
    shortcuts = spark.createDataFrame(
        data, 
        ["src", "dst", "cost", "next_node", "region_id"]
    )
    
    print("Initial shortcuts:")
    shortcuts.show()
    
    # Iterative pipeline
    for iteration in range(3):
        print(f"\n{'='*70}")
        print(f"PIPELINE ITERATION {iteration + 1}")
        print(f"{'='*70}")
        
        # Filter
        filtered = shortcuts.filter(F.col("cost") < 10)
        print(f"\nFiltered: {filtered.count()} shortcuts")
        
        # Compute shortest paths per partition
        new_shortcuts = compute_shortest_paths_per_partition(
            filtered,
            partition_columns=["region_id"],
            max_iterations_per_group=20  # Reduced for safety
        )
        
        print(f"\nNew shortcuts computed: {new_shortcuts.count()}")
        new_shortcuts.show()
        
        # Merge back
        shortcuts = merge_shortcuts_to_main_table(
            shortcuts,
            new_shortcuts,
            partition_columns=["region_id"]
        )
        
        print(f"\nUpdated table: {shortcuts.count()} total shortcuts")
    
    return shortcuts


if __name__ == "__main__":
    # Configuration - still keep memory reasonable
    spark = SparkSession.builder \
        .appName("FixedPureSparkShortestPath") \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.shuffle.partitions", "10") \
        .getOrCreate()
    
    # Set checkpoint directory
    import os
    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    spark.sparkContext.setCheckpointDir(checkpoint_dir)
    
    print("="*70)
    print("FIXED PURE SPARK SQL SHORTEST PATH CALCULATOR")
    print("="*70)
    
    result = example_pipeline(spark)
    
    print("\n" + "="*70)
    print("FINAL RESULT")
    print("="*70)
    result.show()
    
    spark.stop()

FIXED PURE SPARK SQL SHORTEST PATH CALCULATOR
Initial shortcuts:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  C| 5.0|        C|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


PIPELINE ITERATION 1

Filtered: 6 shortcuts


                                                                                


Processing 2 partitions...

[Partition 1/2]
  Partition has 4 edges
  Starting with 4 paths
  Iteration 1... 6 paths
  Iteration 2... 6 paths
  Converged! No new paths added.
  Completed: 6 shortest paths computed

[Partition 2/2]
  Partition has 2 edges
  Starting with 2 paths
  Iteration 1... 3 paths
  Iteration 2... 3 paths
  Converged! No new paths added.
  Completed: 3 shortest paths computed

Combining 2 partition results...

New shortcuts computed: 9
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  A|  C| 3.0|        B|  region1|
|  A|  D| 4.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  B|  D| 3.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  E|  F| 1.0|        F|  region2|
|  E|  G| 3.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


Updated table: 9 total shortcuts

PIPELINE ITERATION 2

Filtered: 9 shortcuts

P