In [1]:
from pyspark.sql.types import StructType, StructField, IntegerType
from pyspark.sql.functions import *

In [2]:
from pyspark.sql import SparkSession

# Notebook-wide Spark session: local master with 2 worker threads and Hive
# support enabled. getOrCreate() returns the already-running session if the
# kernel has one, so re-running this cell is harmless.
spark = (
    SparkSession.builder
    .enableHiveSupport()
    .master("local[2]")
    .getOrCreate()
)

In [3]:
# Schema of one edge row. Columns are matched positionally when this schema is
# applied to a file: first column is `to_v`, second is `from_v`.
# Both columns are non-nullable integers.
graph_schema = (
    StructType()
    .add("to_v", IntegerType(), False)
    .add("from_v", IntegerType(), False)
)

In [4]:
# Schema of a (vertex, distance) row: a vertex id paired with an integer
# distance. Both columns are non-nullable integers.
dist_schema = (
    StructType()
    .add("vertex", IntegerType(), False)
    .add("distance", IntegerType(), False)
)

In [5]:
def shortest_path(v_from, v_to, dataset_path=None):
    """Return the number of edges on a shortest directed path v_from -> v_to.

    Performs a level-by-level breadth-first search over the edge list stored
    at ``dataset_path``. The file is read as tab-separated values with
    ``graph_schema`` (columns ``to_v``, ``from_v``); each row is traversed
    from ``from_v`` to ``to_v``.

    Args:
        v_from: id of the start vertex.
        v_to: id of the target vertex.
        dataset_path: path of the edge-list file. Required; the ``None``
            default is kept only so the signature stays unchanged.

    Returns:
        int: the shortest-path length in edges, ``0`` when
        ``v_from == v_to``, or ``-1`` when ``v_to`` is not reachable
        from ``v_from``.

    Raises:
        ValueError: if ``dataset_path`` is ``None``.
    """
    if dataset_path is None:
        raise ValueError("dataset_path is required")

    # Trivial case: no edge needs to be followed, so no data has to be read.
    if v_from == v_to:
        return 0

    # Every DataFrame we cache is remembered here and released in `finally`,
    # so repeated calls do not pile up cached data in the session.
    cached = []

    def _keep(df):
        cached.append(df.cache())
        return df

    try:
        edges = _keep(spark.read.csv(dataset_path, sep="\t", schema=graph_schema))

        # `visited` holds every vertex reached so far; `frontier` holds only
        # the vertices discovered on the most recent level.
        visited = spark.createDataFrame([(v_from, 0)], dist_schema)
        frontier = visited
        d = 0

        while True:
            # Expand the frontier by one hop. All frontier rows share the same
            # distance, so distinct() leaves one row per vertex. The left-anti
            # join drops vertices already reached at a smaller distance, which
            # guarantees termination on graphs with cycles.
            next_frontier = _keep(
                frontier
                .join(edges, frontier.vertex == edges.from_v)
                .select(edges.to_v.alias("vertex"),
                        (frontier.distance + 1).alias("distance"))
                .distinct()
                .join(visited, on="vertex", how="left_anti")
            )
            d += 1

            # BFS explores by increasing distance, so the first level that
            # contains the target gives the shortest distance.
            if next_frontier.where(next_frontier.vertex == v_to).count() > 0:
                return d

            # Nothing new was discovered: the target is unreachable.
            if next_frontier.count() == 0:
                return -1

            visited = visited.union(next_frontier)
            frontier = next_frontier
    finally:
        for df in cached:
            df.unpersist()

In [6]:
# Shortest-path length from vertex 12 to vertex 34 in the Twitter sample edge list.
d = shortest_path(
    v_from=12,
    v_to=34,
    dataset_path="/data/twitter/twitter_sample.txt",
)
print(d)

8
