In [1]:
import os
# Point PySpark at the local JDK 21 install before any Spark session starts;
# this overrides any JAVA_HOME already present in the environment.
os.environ.update(JAVA_HOME='/usr/lib/jvm/java-21-openjdk-amd64')

In [2]:
"""
Pandas-based implementation of all-pairs shortest paths, computed per group.
Best suited to small and medium groups (roughly 10-500 rows each): every group
is loaded into memory on one executor and processed with vectorized pandas
operations via Spark's applyInPandas.
"""

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType


def compute_shortest_paths_per_partition(
    df_shortcuts: DataFrame,
    partition_columns: list,
    max_iterations_per_group: int = None
) -> DataFrame:
    """
    Computes all-pairs shortest paths independently within each partition group.

    Rows are grouped by ``partition_columns`` and each group is handed to a
    pandas function through ``applyInPandas``. Inside a group, paths are
    extended by repeatedly self-joining the current path table (A->B joined
    with B->C yields a candidate A->C). Only the cheapest path per (src, dst)
    pair is kept. This repeats until the table stops changing, no new
    candidates appear, or the iteration budget is exhausted.

    Note that ``groupBy`` still shuffles the data once, so that each group
    lands on a single executor. Any speed-up over a pure Spark SQL loop comes
    from running the iterative joins in memory with pandas, instead of
    building one Spark plan per iteration. Each group must therefore fit in
    executor memory.

    Args:
        df_shortcuts: DataFrame with columns src, dst, cost, next_node plus
            every column listed in ``partition_columns``.
        partition_columns: List of column names to group by.
        max_iterations_per_group: Maximum number of join rounds per group.
            ``None`` (the default) uses the number of input rows in the group.

    Returns:
        DataFrame with columns src, dst, cost (double), next_node followed by
        the partition columns. It holds the cheapest path found for every
        (src, dst) pair within each group. Derived paths that return to their
        own source (self-loops) are discarded.
    """

    # Partition columns keep their input type. They also keep their input
    # nullability: groupBy forms a group for NULL keys, so declaring them
    # non-nullable could misdescribe the output.
    partition_fields = [
        StructField(
            col,
            df_shortcuts.schema[col].dataType,
            df_shortcuts.schema[col].nullable,
        )
        for col in partition_columns
    ]

    output_schema = StructType([
        StructField("src", StringType(), False),
        StructField("dst", StringType(), False),
        StructField("cost", DoubleType(), False),
        StructField("next_node", StringType(), False)
    ] + partition_fields)

    def process_partition_pandas(pdf):
        """
        Runs the iterative shortest-path relaxation on one group's rows.

        Executed on a Spark executor with the group's rows as a pandas
        DataFrame. Returns a pandas DataFrame whose columns match
        ``output_schema`` by name.
        """
        import pandas as pd

        path_columns = ['src', 'dst', 'cost', 'next_node']

        if len(pdf) == 0:
            return pd.DataFrame(columns=path_columns + partition_columns)

        # Every row in a group shares the same partition values.
        partition_values = {col: pdf[col].iloc[0] for col in partition_columns}

        # Start from the input edges, collapsing parallel edges between the
        # same (src, dst) pair to the cheapest one.
        paths = pdf[path_columns].copy()
        paths = paths.loc[paths.groupby(['src', 'dst'])['cost'].idxmin()].reset_index(drop=True)

        # An explicit 0 means "no join rounds"; only None falls back to the
        # group size.
        if max_iterations_per_group is not None:
            max_iters = max_iterations_per_group
        else:
            max_iters = len(pdf)

        for _ in range(max_iters):
            # Self-join: every A->B combined with every B->C is a candidate A->C.
            joined = paths.merge(
                paths,
                left_on='dst',
                right_on='src',
                suffixes=('_L', '_R')
            )

            # The candidate's first hop is the first hop of its left leg. The
            # frame is built directly (not by assigning into a column subset),
            # which avoids pandas' SettingWithCopyWarning.
            new_paths = pd.DataFrame({
                'src': joined['src_L'],
                'dst': joined['dst_R'],
                'cost': joined['cost_L'] + joined['cost_R'],
                'next_node': joined['next_node_L'],
            })

            # Drop candidates that lead back to their own source.
            new_paths = new_paths[new_paths['src'] != new_paths['dst']]

            if new_paths.empty:
                break  # Nothing left to extend.

            # Keep the cheapest path per (src, dst). Existing paths come first
            # in the concatenation, and idxmin returns the first occurrence of
            # the minimum. So existing paths win cost ties, which lets the
            # convergence check below detect a fixed point.
            combined = pd.concat([paths, new_paths], ignore_index=True)
            updated_paths = combined.loc[
                combined.groupby(['src', 'dst'])['cost'].idxmin()
            ].reset_index(drop=True)

            # Converged when this round produced exactly the same table.
            # (src, dst) is unique in both frames, so the merge is one-to-one.
            if len(updated_paths) == len(paths):
                unchanged = updated_paths.merge(paths, on=path_columns)
                if len(unchanged) == len(paths):
                    break

            paths = updated_paths

        # Re-attach the group's partition values to every result row.
        for col, val in partition_values.items():
            paths[col] = val

        return paths

    return df_shortcuts.groupBy(partition_columns).applyInPandas(
        process_partition_pandas,
        schema=output_schema
    )


def merge_shortcuts_to_main_table(
    main_df: DataFrame,
    new_shortcuts: DataFrame,
    partition_columns: list
) -> DataFrame:
    """
    Replaces the shortcuts of every recomputed partition with the new ones.

    Every partition key that appears in ``new_shortcuts`` is treated as
    recomputed. All of that partition's rows in ``main_df`` are dropped,
    including rows that were not part of the recomputation input (e.g. rows a
    caller filtered out beforehand). The rows of ``new_shortcuts`` are then
    appended. Partitions absent from ``new_shortcuts`` are kept unchanged.

    Partition keys are compared with null-safe equality. A group whose key is
    NULL is therefore replaced rather than duplicated.

    Args:
        main_df: Main shortcut table.
        new_shortcuts: Newly computed shortcuts. They must have the same
            column names as ``main_df``.
        partition_columns: Columns used for partitioning.

    Returns:
        Updated DataFrame: untouched partitions from ``main_df`` plus all rows
        of ``new_shortcuts``, combined by column name.
    """
    # Partition keys that were recomputed.
    recomputed_partitions = new_shortcuts.select(partition_columns).distinct()

    # left_anti keeps only main rows whose partition key was not recomputed.
    # eqNullSafe (SQL <=>) lets NULL keys match each other; plain == never
    # matches NULL, which would leave stale rows of a NULL-key group behind.
    # TODO: instead of this join, the caller could pass the part of main_df
    # that it filtered out.
    match_condition = [
        F.col(f"main.{col}").eqNullSafe(F.col(f"recomp.{col}"))
        for col in partition_columns
    ]
    remaining_df = main_df.alias("main").join(
        recomputed_partitions.alias("recomp"),
        match_condition,
        "left_anti"
    )

    return remaining_df.unionByName(new_shortcuts)


def example_pipeline(spark: SparkSession):
    """
    Runs a small demo of an iterative filter -> recompute -> merge pipeline.

    Builds a six-edge shortcuts table over two regions. It then runs three
    iterations, each of which:

    1. keeps shortcuts with cost < 10;
    2. recomputes shortest paths per ``region_id``;
    3. merges the result back into the table.

    Intermediate tables are printed with ``show()``.

    Args:
        spark: Active SparkSession used to create the demo data.

    Returns:
        The shortcuts DataFrame after the final iteration.
    """

    # Initial shortcuts table.
    data = [
        ("A", "B", 1.0, "B", "region1"),
        ("B", "C", 2.0, "C", "region1"),
        ("C", "D", 1.0, "D", "region1"),
        ("A", "C", 5.0, "C", "region1"),  # Suboptimal shortcut
        ("E", "F", 1.0, "F", "region2"),
        ("F", "G", 2.0, "G", "region2"),
    ]

    shortcuts = spark.createDataFrame(
        data,
        ["src", "dst", "cost", "next_node", "region_id"]
    )

    print("Initial shortcuts:")
    shortcuts.show()

    for iteration in range(3):
        print(f"\n{'='*60}")
        print(f"ITERATION {iteration + 1}")
        print(f"{'='*60}")

        # Step 1: Filter (example: keep only cost < 10).
        filtered = shortcuts.filter(F.col("cost") < 10)

        print("\nFiltered shortcuts (cost < 10):")
        filtered.show()

        # Step 2: Compute shortest paths per partition using pandas.
        new_shortcuts = compute_shortest_paths_per_partition(
            filtered,
            partition_columns=["region_id"],
            max_iterations_per_group=100
        )

        print("\nNew shortcuts computed:")
        new_shortcuts.show()

        # Step 3: Merge back into the main table.
        shortcuts = merge_shortcuts_to_main_table(
            shortcuts,
            new_shortcuts,
            partition_columns=["region_id"]
        )

        # The merged plan references the previous `shortcuts` three times:
        # as the main table, and twice through `new_shortcuts` (recomputed
        # keys and union side). Without truncation, the plan grows
        # exponentially across iterations, and so does the work repeated by
        # every show()/count(). localCheckpoint() materializes the table and
        # cuts its lineage.
        shortcuts = shortcuts.localCheckpoint()

        print("\nUpdated shortcuts table:")
        shortcuts.show()

        # Report the table size (a convergence check could compare counts).
        shortcut_count = shortcuts.count()
        print(f"Total shortcuts: {shortcut_count}")

    return shortcuts


def benchmark_comparison(spark: SparkSession):
    """
    Times one shortest-path computation on a synthetic 20-node graph.

    In the graph, each node N{i} has edges to the next (up to) four nodes
    N{i+1}..N{i+4}, with cost equal to the index difference. All edges are in
    the single partition "partition1". The timing covers building, caching
    and counting the result. The input edge count is computed before the
    timer starts.

    Args:
        spark: Active SparkSession used to create the input data.

    Returns:
        The result DataFrame, left cached.
    """
    import time

    # Synthetic graph: forward edges spanning 1 to 4 nodes.
    data = []
    for i in range(20):
        for j in range(i + 1, min(i + 5, 20)):
            data.append((f"N{i}", f"N{j}", float(j - i), f"N{j}", "partition1"))

    df = spark.createDataFrame(data, ["src", "dst", "cost", "next_node", "region_id"])

    print(f"Input: {df.count()} edges")

    # perf_counter is the monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() can jump with wall-clock changes.
    start = time.perf_counter()
    result = compute_shortest_paths_per_partition(
        df,
        partition_columns=["region_id"],
        max_iterations_per_group=50
    )
    result.cache()
    result_count = result.count()  # Forces the computation being timed.
    end = time.perf_counter()

    print(f"\nResult: {result_count} shortest paths")
    print(f"Time taken: {end - start:.2f} seconds")

    return result


if __name__ == "__main__":
    spark = (
        SparkSession.builder
        .appName("PandasShortestPath")
        .config("spark.driver.memory", "4g")
        .getOrCreate()
    )

    try:
        print("="*60)
        print("PANDAS-OPTIMIZED SHORTEST PATH CALCULATOR")
        print("="*60)

        # Run example pipeline.
        result = example_pipeline(spark)

        print("\n" + "="*60)
        print("FINAL RESULT")
        print("="*60)
        result.show()

        print("\n" + "="*60)
        print("BENCHMARK")
        print("="*60)
        benchmark_comparison(spark)
    finally:
        # Stop the session even if one of the demo steps raises.
        spark.stop()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/11/17 00:06:10 WARN Utils: Your hostname, Bamdad-Beast, resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/11/17 00:06:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/17 00:06:11 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


PANDAS-OPTIMIZED SHORTEST PATH CALCULATOR
Initial shortcuts:


                                                                                

+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  C| 5.0|        C|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


ITERATION 1

Filtered shortcuts (cost < 10):
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  A|  C| 5.0|        C|  region1|
|  E|  F| 1.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


New shortcuts computed:


                                                                                

+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  A|  C| 3.0|        B|  region1|
|  A|  D| 4.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  B|  D| 3.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  E|  F| 1.0|        F|  region2|
|  E|  G| 3.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+


Updated shortcuts table:
+---+---+----+---------+---------+
|src|dst|cost|next_node|region_id|
+---+---+----+---------+---------+
|  A|  B| 1.0|        B|  region1|
|  A|  C| 3.0|        B|  region1|
|  A|  D| 4.0|        B|  region1|
|  B|  C| 2.0|        C|  region1|
|  B|  D| 3.0|        C|  region1|
|  C|  D| 1.0|        D|  region1|
|  E|  F| 1.0|        F|  region2|
|  E|  G| 3.0|        F|  region2|
|  F|  G| 2.0|        G|  region2|
+---+---+----+---------+---------+

Total shortcuts: 9

ITERATION 2

Filtered shortcuts (cost < 10

In [None]:
def converging_check(
    old_paths: DataFrame,
    new_paths: DataFrame,
    key_columns: tuple = ("incoming_edge", "outgoing_edge", "cost"),
) -> bool:
    """
    Checks whether two path DataFrames hold the same rows on ``key_columns``.

    Used to determine convergence of an iterative shortest-path computation.
    The tables count as identical when every distinct combination of
    ``key_columns`` in one also appears in the other. Duplicate rows are
    ignored, and NULLs compare equal to each other.

    Args:
        old_paths: Paths from the previous iteration.
        new_paths: Paths from the current iteration.
        key_columns: Names of the columns to compare (any sequence of
            strings). Both DataFrames must contain them. Pass e.g.
            ("src", "dst", "cost") for tables using that schema.

    Returns:
        True if both tables contain the same set of key-column rows.
    """
    cols = list(key_columns)
    old_keys = old_paths.select(cols)
    new_keys = new_paths.select(cols)

    # subtract() is EXCEPT DISTINCT: a set difference that treats NULLs as
    # equal. Both directions are required; a one-sided check would report
    # convergence whenever new_paths is merely a subset of old_paths.
    # limit(1) stops each check as soon as one differing row is found.
    if new_keys.subtract(old_keys).limit(1).count() > 0:
        return False
    return old_keys.subtract(new_keys).limit(1).count() == 0
    