In [1]:
import os

# PySpark launches the JVM through spark-submit, which locates Java via JAVA_HOME.
# This must therefore run before any SparkSession is created.
# Respect a JAVA_HOME that is already configured in the environment. Fall back to
# this machine-specific OpenJDK 21 path only when none is set.
os.environ.setdefault("JAVA_HOME", "/usr/lib/jvm/java-21-openjdk-amd64")

In [2]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
import h3

## 1. Initialize Spark

In [3]:
def initialize_spark(
    app_name: str = "AllPairsShortestPath",
    driver_memory: str = "8g",
    checkpoint_dir: str = "checkpoints",
) -> SparkSession:
    """
    Initializes and returns a SparkSession, and sets its checkpoint directory.

    Args:
        app_name: Spark application name shown in the Spark UI.
        driver_memory: Value for ``spark.driver.memory`` (e.g. ``"8g"``).
        checkpoint_dir: Directory passed to ``SparkContext.setCheckpointDir``,
            used by ``DataFrame.checkpoint()``.

    Returns:
        The active SparkSession.

    Notes:
        ``getOrCreate`` returns an already-running session if one exists in this
        kernel. In that case ``driver_memory`` cannot take effect, because driver
        memory is fixed when the JVM starts. Restart the kernel to change it.
    """
    spark = (
        SparkSession.builder.appName(app_name)
        .config("spark.driver.memory", driver_memory)
        .getOrCreate()
    )
    # Checkpointing truncates long lineages, e.g. in iterative shortest-path loops.
    spark.sparkContext.setCheckpointDir(checkpoint_dir)
    return spark

## 2. Read edge data and initialize the shortcuts table

In [4]:
def read_edges(spark: SparkSession, file_path: str) -> DataFrame:
    """
    Load the edge list from a CSV file with a header row.

    Spark infers column types (``inferSchema=True``). Only the columns used
    downstream are kept.

    Args:
        spark: Active SparkSession.
        file_path: Path to the edges CSV.

    Returns:
        DataFrame with columns ``id``, ``incoming_cell``, ``outgoing_cell``,
        ``lca_res`` (in that order).
    """
    wanted_columns = ["id", "incoming_cell", "outgoing_cell", "lca_res"]
    raw_edges = spark.read.csv(file_path, header=True, inferSchema=True)
    return raw_edges.select(*wanted_columns)

In [5]:
def initial_shortcuts_table(spark: SparkSession, file_path: str, edges_cost_df: DataFrame) -> DataFrame:
    """
    Build the initial shortcuts table from a CSV of edge pairs and attach costs.

    The (incoming_edge, outgoing_edge) pairs are read from ``file_path``. Each
    pair starts with ``next_edge`` equal to ``outgoing_edge``. ``cost`` is looked
    up from ``edges_cost_df`` by matching ``outgoing_edge`` against its ``id``
    column.

    Args:
        spark: Active SparkSession.
        file_path: Path to a CSV with header columns ``incoming_edge`` and
            ``outgoing_edge``.
        edges_cost_df: DataFrame with at least ``id`` and ``cost`` columns.

    Returns:
        DataFrame with columns ``incoming_edge``, ``outgoing_edge``,
        ``next_edge``, ``cost``. This is a left join, so ``cost`` is null where no
        edge id matches.
    """
    edge_costs = edges_cost_df.select("id", "cost")

    pairs = (
        spark.read.csv(file_path, header=True, inferSchema=True)
        .select("incoming_edge", "outgoing_edge")
        .withColumn("next_edge", F.col("outgoing_edge"))
    )

    # Left join keeps every pair even when its outgoing edge has no cost row.
    # The helper "id" column is dropped afterwards.
    with_costs = pairs.join(
        edge_costs,
        pairs["outgoing_edge"] == edge_costs["id"],
        how="left",
    )
    return with_costs.drop("id")