In [248]:
import os
# Set JAVA_HOME for this process before PySpark is imported in the next cell
# (presumably so that Spark launches this specific JVM).
# NOTE(review): machine-specific absolute path — confirm it exists on the target host.
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-21-openjdk-amd64'

In [249]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
import h3
from pyspark.sql import Window

## 1- Initialize Spark

In [250]:
def initialize_spark(
    app_name: str = "AllPairsShortestPath",
    driver_memory: str = "8g",
    checkpoint_dir: str = "checkpoints",
) -> SparkSession:
    """
    Creates (or reuses) a SparkSession and configures its checkpoint directory.

    Args:
        app_name: Application name passed to the session builder.
        driver_memory: Value for ``spark.driver.memory`` (e.g. ``"8g"``).
            Only takes effect when a new session is created; ``getOrCreate``
            returns an already-running session unchanged.
        checkpoint_dir: Directory handed to ``SparkContext.setCheckpointDir``;
            required by ``DataFrame.checkpoint()`` calls made later.

    Returns:
        The active SparkSession.
    """
    spark = (
        SparkSession.builder.appName(app_name)
        .config("spark.driver.memory", driver_memory)
        .getOrCreate()
    )
    spark.sparkContext.setCheckpointDir(checkpoint_dir)
    return spark

## 2- Read edges data and initialize shortcuts table

In [251]:
def read_edges(spark: SparkSession, file_path: str) -> DataFrame:
    """
    Loads the edge CSV (header row, inferred schema) and keeps only the
    columns used downstream: id, incoming_cell, outgoing_cell, lca_res.
    """
    wanted_columns = ["id", "incoming_cell", "outgoing_cell", "lca_res"]
    raw_edges = spark.read.csv(file_path, header=True, inferSchema=True)
    return raw_edges.select(*wanted_columns)

In [252]:
def initial_shortcuts_table(spark: SparkSession, file_path: str, edges_cost_df: DataFrame) -> DataFrame:
    """
    Builds the starting shortcuts table from the edge-graph CSV.

    Each (incoming_edge, outgoing_edge) pair read from ``file_path`` gets:
      * ``via_edge``  - a copy of ``outgoing_edge``
      * ``cost``      - the cost of the *incoming* edge, looked up by id in
                        ``edges_cost_df`` (null when no edge matches, because
                        the lookup is a left join).
    """
    edge_pairs = (
        spark.read.csv(file_path, header=True, inferSchema=True)
        .select("incoming_edge", "outgoing_edge")
        .withColumn("via_edge", F.col("outgoing_edge"))
    )

    # Rename the key so it can be dropped by name after the join.
    cost_lookup = edges_cost_df.select(
        F.col("id").alias("_cost_edge_id"),
        F.col("cost"),
    )

    with_cost = edge_pairs.join(
        cost_lookup,
        F.col("incoming_edge") == F.col("_cost_edge_id"),
        "left",
    )
    return with_cost.drop("_cost_edge_id")

## 3- Update edge costs with a dummy cost function

In [253]:
@F.udf(returnType=DoubleType())
def dummy_cost(length, maxspeed):
    """
    Placeholder edge cost: length divided by maxspeed.

    Returns:
        * None (SQL NULL) when either input is null, so that a missing value
          in the CSV does not raise inside the Python worker and fail the task.
        * +inf when maxspeed is zero or negative.
        * float(length) / float(maxspeed) otherwise.
    """
    # Null inputs arrive as None; comparing or converting None would raise TypeError.
    if length is None or maxspeed is None:
        return None
    if maxspeed <= 0:
        return float('inf')  # Return infinity for invalid speeds
    # Cost is proportional to length and inversely proportional to speed
    return float(length) / float(maxspeed)

def update_dummy_costs_for_edges(spark: SparkSession, file_path: str, edges_df: DataFrame):
    """
    Returns ``edges_df`` with a freshly computed ``cost`` column.

    The cost is derived from the ``length`` and ``maxspeed`` columns of the CSV
    at ``file_path`` via ``dummy_cost`` and attached by a left join on ``id``.
    Any ``cost`` column already present on ``edges_df`` is replaced.
    """
    cost_by_id = (
        spark.read.csv(file_path, header=True, inferSchema=True)
        .select("id", "length", "maxspeed")
        .withColumn("cost", dummy_cost(F.col("length"), F.col("maxspeed")))
        .select("id", "cost")
    )
    # drop() is a no-op when edges_df has no "cost" column yet.
    edges_without_cost = edges_df.drop("cost")
    return edges_without_cost.join(cost_by_id, on="id", how="left")

## 4- Add info "via_cell", "via_res" and "lca_res" to shortcuts table

In [254]:
# LCA - Lowest Common Ancestor resolution
@F.udf(StringType())
def find_LCA(cell1, cell2):
    """
    Returns the lowest common ancestor (finest shared parent) of two H3 cells.

    Starting at the coarser of the two cells' resolutions, walks towards
    resolution 0 and returns the first parent cell shared by both inputs.

    Returns None (SQL NULL) when either input is null or when the cells share
    no ancestor even at resolution 0.
    """
    # Null cells occur when an upstream left join finds no matching edge;
    # h3 would raise on None, so short-circuit like find_resolution does.
    if cell1 is None or cell2 is None:
        return None

    lca_res = min(h3.get_resolution(cell1), h3.get_resolution(cell2))
    while lca_res >= 0:
        parent1 = h3.cell_to_parent(cell1, lca_res)
        if parent1 == h3.cell_to_parent(cell2, lca_res):
            return parent1
        lca_res -= 1
    return None  # no common ancestor, not even at resolution 0

# Resolution of a cell
@F.udf(IntegerType())
def find_resolution(cell):
    """Returns the H3 resolution of ``cell``, or -1 when ``cell`` is null."""
    return -1 if cell is None else h3.get_resolution(cell)

def add_info_for_shortcuts(spark: SparkSession, shortcuts_df: DataFrame, edges_df: DataFrame) -> DataFrame:
    """
    Recomputes the ``lca_res``, ``via_cell`` and ``via_res`` columns of the shortcuts table.

    For every shortcut:
      * ``lca_res``  = greatest of the ``lca_res`` values of its incoming and outgoing edges
      * ``via_cell`` = ``find_LCA`` of the incoming edge's ``incoming_cell`` and
                       the outgoing edge's ``outgoing_cell``
      * ``via_res``  = ``find_resolution(via_cell)``

    Edge attributes are attached with left joins on edge id, so shortcuts whose
    edges are missing from ``edges_df`` are kept. ``spark`` is not used.
    """
    # Start from a clean slate; drop() silently ignores columns that are absent.
    base = shortcuts_df.drop("lca_res", "via_cell", "via_res")

    # Edge attributes seen from the incoming-edge side.
    incoming_side = edges_df.select(
        F.col("id").alias("incoming_edge_id"),
        F.col("incoming_cell").alias("incoming_cell_in"),
        F.col("lca_res").alias("lca_res_in"),
    )
    # Edge attributes seen from the outgoing-edge side.
    outgoing_side = edges_df.select(
        F.col("id").alias("outgoing_edge_id"),
        F.col("outgoing_cell").alias("outgoing_cell_out"),
        F.col("lca_res").alias("lca_res_out"),
    )

    with_edge_info = (
        base
        .join(incoming_side, F.col("incoming_edge") == F.col("incoming_edge_id"), "left")
        .drop("incoming_edge_id")
        .join(outgoing_side, F.col("outgoing_edge") == F.col("outgoing_edge_id"), "left")
        .drop("outgoing_edge_id")
    )

    return (
        with_edge_info
        .withColumn("lca_res", F.greatest(F.col("lca_res_in"), F.col("lca_res_out")))
        .drop("lca_res_in", "lca_res_out")
        .withColumn("via_cell", find_LCA(F.col("incoming_cell_in"), F.col("outgoing_cell_out")))
        .withColumn("via_res", find_resolution(F.col("via_cell")))
        .drop("incoming_cell_in", "outgoing_cell_out")
    )

## 5- Filter shortcuts table based on current resolution

In [255]:
def filter_shortcuts_by_resolution(shortcuts_df: DataFrame, current_res: int) -> DataFrame:
    """
    Selects the shortcuts that are relevant at ``current_res``.

    A row is kept when both conditions hold:
      * ``lca_res <= current_res``
      * ``via_res >= current_res``
    """
    lca_not_finer = F.col("lca_res") <= current_res
    via_not_coarser = F.col("via_res") >= current_res
    return shortcuts_df.filter(lca_not_finer & via_not_coarser)

## 6- Add a cell container to each shortcut based on the current resolution

In [256]:
def add_parent_cell_at_resolution(shortcuts_df: DataFrame, current_resolution: int) -> DataFrame:
    """
    Tags every shortcut with the H3 cell that contains it at ``current_resolution``.

    The container is derived from ``via_cell``: its parent at
    ``current_resolution``, or ``via_cell`` itself when it is already at that
    resolution or coarser. Null or unparseable cells yield a null container.

    Args:
        shortcuts_df: Shortcuts with a ``via_cell`` column.
        current_resolution: H3 resolution used for the container cell.

    Returns:
        The input with a new ``current_cell`` column and with ``lca_res``,
        ``via_cell`` and ``via_res`` removed.
    """

    def _container_cell(cell, target_resolution):
        # Plain-Python body of the UDF; any h3 failure maps to NULL (best effort).
        if cell is None:
            return None
        try:
            if h3.get_resolution(cell) <= target_resolution:
                # Cell is already at, or coarser than, the requested level.
                return cell
            return h3.cell_to_parent(cell, target_resolution)
        except Exception:
            return None

    container_udf = F.udf(_container_cell, StringType())

    with_container = shortcuts_df.withColumn(
        "current_cell",
        container_udf(F.col("via_cell"), F.lit(current_resolution)),
    )
    # Rows with a null current_cell are intentionally kept.
    return with_container.drop("lca_res", "via_cell", "via_res")

## 7- Shortest path in pure Spark

In [257]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

def update_convergence_status(
    shortcuts_df_last: DataFrame,
    shortcuts_df_new: DataFrame
) -> DataFrame:
    """
    Sets (or overwrites) the boolean ``is_converged`` column on ``shortcuts_df_new``.

    A row of ``shortcuts_df_new`` counts as *changed* when no row of
    ``shortcuts_df_last`` has the same (incoming_edge, outgoing_edge, cost).
    Every ``current_cell`` that owns at least one changed row is *active*;
    all rows of an active cell get ``is_converged = False`` and all other rows
    get ``True``.

    Args:
        shortcuts_df_last: Shortcuts from the previous iteration.
        shortcuts_df_new: Shortcuts from the current iteration; must contain
            ``current_cell`` plus the three comparison keys.

    Returns:
        ``shortcuts_df_new`` with ``is_converged`` updated. Because the flag is
        attached through a join on ``current_cell``, that column comes first in
        the result; no helper columns are left behind.
    """
    # The keys that define a shortcut's identity for change detection.
    join_keys = ["incoming_edge", "outgoing_edge", "cost"]

    # Rows in 'new' with no identical counterpart in 'last' are new or changed.
    changed_rows = shortcuts_df_new.join(
        shortcuts_df_last,
        on=join_keys,
        how="left_anti",
    )

    # One row per active cell. The key is duplicated under a private name so
    # the outer join below exposes an ordinary column that is NULL exactly when
    # the cell is not active. (The previous version relied on resolving the
    # hidden right-side key as "active.current_cell" and then tried to drop it
    # by that dotted string, which drop() treats as a single, non-existent
    # column name.)
    active_cells = (
        changed_rows
        .select("current_cell")
        .distinct()
        .withColumn("_active_cell", F.col("current_cell"))
    )

    flagged = shortcuts_df_new.join(
        active_cells,
        on="current_cell",
        how="left_outer",
    )

    return (
        flagged
        .withColumn("is_converged", F.col("_active_cell").isNull())
        .drop("_active_cell")
    )

In [258]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

# NOTE: The necessary 'update_convergence_status' function definition is assumed to exist.

def run_grouped_shortest_path_with_convergence(
    shortcuts_df: DataFrame,
    max_iterations: int = 4,
) -> DataFrame:
    """
    Iteratively shortens shortcut costs inside each ``current_cell`` group.

    Every iteration joins the current path set with itself (A->B with B->C,
    same ``current_cell``) to propose A->C paths, unions the proposals with the
    existing paths, and keeps the cheapest row per
    (incoming_edge, outgoing_edge, current_cell); ties are broken by the
    smallest ``via_edge``. Since each round can concatenate two existing
    paths, ``k`` rounds cover paths of up to ``2**k`` original shortcuts.

    The loop ends when one of the following happens:
      * no candidate path can be formed,
      * an iteration leaves the path set unchanged (fixed point), or
      * ``max_iterations`` rounds have run.

    Args:
        shortcuts_df: Must contain incoming_edge, outgoing_edge, cost,
            via_edge and current_cell.
        max_iterations: Upper bound on relaxation rounds (default 4, the
            previously hard-coded value).

    Returns:
        DataFrame with incoming_edge, outgoing_edge, cost and via_edge. It is
        backed by a cached intermediate result that is left persisted for the
        caller's actions.
    """
    path_columns = ["incoming_edge", "outgoing_edge", "cost", "via_edge", "current_cell"]

    # 1. INITIALIZE: start from the original shortcuts. 'is_converged' is kept
    #    as a hook for per-cell convergence flags; it is False for every row.
    current_paths = (
        shortcuts_df.select(*path_columns)
        .withColumn("is_converged", F.lit(False))
        .cache()
    )

    # Cheapest path per (start, end, cell); via_edge is a deterministic tie-breaker.
    best_path_window = Window.partitionBy(
        "incoming_edge", "outgoing_edge", "current_cell"
    ).orderBy(F.col("cost").asc(), F.col("via_edge").asc())

    for i in range(max_iterations):
        # --- 2. PATH EXTENSION: chain L (A->B) with R (B->C) inside one cell ---
        new_paths = (
            current_paths.alias("L")
            .join(
                current_paths.alias("R"),
                [
                    F.col("L.outgoing_edge") == F.col("R.incoming_edge"),
                    F.col("L.current_cell") == F.col("R.current_cell"),
                ],
                "inner",
            )
            .filter(
                # No path that ends on the edge it started from
                (F.col("L.incoming_edge") != F.col("R.outgoing_edge"))
                # Only extend rows that are not flagged as converged
                & (F.col("L.is_converged") == F.lit(False))
                & (F.col("R.is_converged") == F.lit(False))
            )
            .select(
                F.col("L.incoming_edge").alias("incoming_edge"),
                F.col("R.outgoing_edge").alias("outgoing_edge"),
                (F.col("L.cost") + F.col("R.cost")).alias("cost"),
                F.col("L.outgoing_edge").alias("via_edge"),
                F.col("L.current_cell").alias("current_cell"),
                F.lit(False).alias("is_converged"),
            )
            .cache()
        )

        if new_paths.limit(1).count() == 0:
            print("======= len new_path is zero!! ==========")
            new_paths.unpersist()  # nothing to merge; release the empty cache entry
            break

        # --- 3. COST MINIMIZATION: keep the single best row per key ---
        next_paths = (
            current_paths.unionByName(new_paths)
            .withColumn("rnk", F.row_number().over(best_path_window))
            .filter(F.col("rnk") == 1)
            .drop("rnk")
            .cache()
        )

        # --- 4. CONVERGENCE CHECK: rows of next_paths with no identical row in
        #     current_paths. Null-safe equality so rows with a NULL current_cell
        #     do not look "changed" forever. The full count() also materialises
        #     the next_paths cache before its inputs are released below.
        same_row = [
            F.col(f"N.{name}").eqNullSafe(F.col(f"P.{name}")) for name in path_columns
        ]
        changed_count = (
            next_paths.alias("N")
            .join(current_paths.alias("P"), same_row, "left_anti")
            .count()
        )

        # Inputs of this round are no longer needed once next_paths is cached.
        current_paths.unpersist()
        new_paths.unpersist()
        current_paths = next_paths

        if changed_count == 0:
            # Fixed point: another round would reproduce exactly the same rows.
            print(f"\u2705 {i} - Graph fully converged. Exiting shortest path calculation.")
            break

    return current_paths.drop("is_converged", "current_cell")

## 8- Merge new shortcuts into the main table

In [259]:
def merge_shortcuts_to_main_table(
    main_df: DataFrame,
    new_shortcuts: DataFrame,
) -> DataFrame:
    """
    Replaces shortcuts in the main table with newly computed ones.

    Both inputs are first reduced to the columns incoming_edge, outgoing_edge,
    cost and via_edge. Rows of ``main_df`` whose
    (incoming_edge, outgoing_edge) pair also occurs in ``new_shortcuts`` are
    removed, then all rows of ``new_shortcuts`` are appended.

    Args:
        main_df: Main shortcut table.
        new_shortcuts: Newly computed shortcuts.

    Returns:
        DataFrame with the four shortcut columns.
    """
    shortcut_columns = ["incoming_edge", "outgoing_edge", "cost", "via_edge"]
    existing = main_df.select(*shortcut_columns)
    replacements = new_shortcuts.select(*shortcut_columns)

    same_endpoints = (
        (F.col("main.incoming_edge") == F.col("update.incoming_edge"))
        & (F.col("main.outgoing_edge") == F.col("update.outgoing_edge"))
    )

    # Anti-join keeps only the main-table rows that are NOT being replaced.
    untouched = existing.alias("main").join(
        replacements.alias("update"),
        on=same_endpoints,
        how="left_anti",
    )

    return untouched.unionByName(replacements)

In [260]:
# Input files: the edge table is read twice (cell attributes, then costs).
edges_csv_path = "data/burnaby_driving_simplified_edges_with_h3.csv"
edge_graph_csv_path = "data/burnaby_driving_edge_graph.csv"

spark = initialize_spark()
edges_df = read_edges(spark, file_path=edges_csv_path).cache()
edges_cost_df = update_dummy_costs_for_edges(
    spark,
    file_path=edges_csv_path,
    edges_df=edges_df,
)
shortcuts_df = initial_shortcuts_table(
    spark,
    file_path=edge_graph_csv_path,
    edges_cost_df=edges_cost_df,
)


25/11/18 15:40:49 WARN CacheManager: Asked to cache already cached data.


In [None]:
# Driver loop over H3 resolutions, finest first.
# NOTE: range(14, 13, -1) yields only resolution 14, so the body runs once.
for current_resolution in range(14,13,-1):
    # Recompute lca_res / via_cell / via_res for the current shortcuts table.
    shortcuts_df = add_info_for_shortcuts(spark, shortcuts_df, edges_df) 
    # Checkpoint into the directory configured in initialize_spark().
    shortcuts_df = shortcuts_df.checkpoint()
    shortcuts_df.show(5)
    # Keep the shortcuts relevant at this resolution and tag each with its container cell.
    shortcuts_df_filtered =filter_shortcuts_by_resolution(shortcuts_df,current_res=current_resolution)
    shortcuts_df_with_cell =add_parent_cell_at_resolution(shortcuts_df_filtered, current_resolution)
    
    # Per-cell shortest-path computation on the filtered shortcuts.
    shortcuts_df_new = run_grouped_shortest_path_with_convergence( shortcuts_df_with_cell)
    # NOTE(review): the merge back into shortcuts_df is disabled, so results
    # would not carry over to a next resolution — confirm this is intentional.
    #shortcuts_df = merge_shortcuts_to_main_table(shortcuts_df, shortcuts_df_new)
    print(f"number of shortcut in resolution {current_resolution} : {shortcuts_df_new.count()}")
    #shortcuts_df_new.count()
#spark.stop()

+--------------------+--------------------+--------------------+------------------+-------+---------------+-------+
|       incoming_edge|       outgoing_edge|            via_edge|              cost|lca_res|       via_cell|via_res|
+--------------------+--------------------+--------------------+------------------+-------+---------------+-------+
|(8703571310, 8703...|(8703571308, 8703...|(8703571308, 8703...|1.7882000000000002|     10|8f28de881760406|     15|
|(7314138295, 6092...|(6092498237, 7314...|(6092498237, 7314...|             0.992|     11|8f28de8886c3b50|     15|
|(9068804417, 9068...|(9068804425, 9463...|(9068804425, 9463...|3.4905999999999997|     11|8f28de8982e6759|     15|
|(9289423408, 9289...|(9289423412, 9289...|(9289423412, 9289...|            2.6675|     11|8f28de1228c9309|     15|
|(416104062, 34741...|(347415013, 25364...|(347415013, 25364...|1.9208666666666665|      9|8f28de898932612|     15|
+--------------------+--------------------+--------------------+--------

In [270]:
# Stop the Spark session created by initialize_spark(); DataFrames built on it
# (e.g. shortcuts_df_new) cannot be evaluated after this cell has run.
spark.stop()

In [None]:
#df = shortcuts_df_new.toPandas()