In [1]:
from pyspark.sql.types import StructType, StructField, IntegerType
import pyspark.sql.functions as f

In [2]:
from pyspark.sql import SparkSession

# Local Spark session running with 2 worker threads ("local[2]"), with Hive
# support enabled; getOrCreate() reuses an already-running session if present.
spark = (
    SparkSession.builder
    .master("local[2]")
    .enableHiveSupport()
    .getOrCreate()
)

In [3]:
# Schema of the tab-separated edge list read by shortest_path. Columns are
# matched by position, so the file's first column is `to_v` and its second is
# `from_v`; shortest_path follows each edge from `from_v` to `to_v`.
graph_schema = StructType([
    StructField(name="to_v", dataType=IntegerType(), nullable=False),
    StructField(name="from_v", dataType=IntegerType(), nullable=False),
])

In [4]:
# Schema of the BFS distance table used by shortest_path: one row per reached
# vertex together with its hop count from the start vertex.
dist_schema = StructType([
    StructField(name="vertex", dataType=IntegerType(), nullable=False),
    StructField(name="distance", dataType=IntegerType(), nullable=False),
])

In [7]:
def shortest_path(v_from, v_to, dataset_path=None):
    """Return the length of a shortest directed path from ``v_from`` to ``v_to``.

    Runs a level-synchronous breadth-first search with Spark over an unweighted
    edge list: a tab-separated file read with ``graph_schema``, where each row
    is an edge ``from_v -> to_v``.

    Args:
        v_from: id of the start vertex.
        v_to: id of the target vertex.
        dataset_path: path of the edge-list file (anything ``spark.read.csv``
            accepts). It is required; the ``None`` default only keeps the
            original signature.

    Returns:
        The number of hops on a shortest path (0 when ``v_from == v_to``), or
        ``None`` when ``v_to`` is unreachable from ``v_from``.

    Raises:
        ValueError: if ``dataset_path`` is not given.
    """
    if dataset_path is None:
        raise ValueError("dataset_path is required")

    edges = spark.read.csv(dataset_path, sep="\t", schema=graph_schema)
    edges.cache()
    # Every cached DataFrame is tracked so it is released on all exit paths.
    cached = [edges]
    try:
        if v_from == v_to:
            return 0

        # `frontier` holds only the vertices discovered in the last round;
        # `visited` accumulates every vertex reached so far.
        frontier = spark.createDataFrame([(v_from, 0)], dist_schema)
        visited = frontier
        while True:
            # One BFS step: follow the edges leaving the frontier, de-duplicate
            # targets, and keep only vertices never seen before. The anti-join
            # is what guarantees termination on graphs with cycles; distinct()
            # stops rows from multiplying when several edges hit one vertex.
            # All frontier rows share one distance, so distinct() on
            # (vertex, distance) is distinct on vertex.
            next_frontier = (
                frontier.join(edges, frontier.vertex == edges.from_v)
                .select(edges.to_v.alias("vertex"),
                        (frontier.distance + 1).alias("distance"))
                .distinct()
                .join(visited.select("vertex"), on="vertex", how="left_anti")
            )
            next_frontier.cache()
            cached.append(next_frontier)

            # count() also materializes the cache used by the lookup below.
            if next_frontier.count() == 0:
                return None  # search exhausted without reaching v_to

            hit = next_frontier.where(next_frontier.vertex == v_to).first()
            if hit is not None:
                return hit["distance"]

            visited = visited.union(next_frontier)
            frontier = next_frontier
    finally:
        for df in cached:
            df.unpersist()
        

In [8]:
# Example query on the Twitter sample edge list: hop distance from vertex 12 to
# vertex 34 (None would mean unreachable). A single-argument print(...) behaves
# the same under Python 2 and Python 3, unlike the Python 2-only print statement.
print(shortest_path(12, 34, "/data/twitter/twitter_sample2.txt"))

8
