In [44]:
from pyspark.sql.types import StructType, StructField, IntegerType
import pyspark.sql.functions as f

In [45]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.enableHiveSupport().master("local[2]").getOrCreate()

In [46]:
graph_schema = StructType([
    StructField("to_v", IntegerType(), False),
    StructField("from_v", IntegerType(), False)
])

In [47]:
dist_schema = StructType([
    StructField("vertex", IntegerType(), False),
    StructField("distance", IntegerType(), False)
])

In [103]:
def shortest_path(v_from, v_to, dataset_path=None):
    
    edges = spark.read.csv(dataset_path, sep="\t", schema=graph_schema)       
    edges.cache()
    
    dist = spark.createDataFrame(spark.sparkContext.parallelize([(v_from, 0)]), dist_schema)
    result = dist
    d = 0
    while True:
        
        candidates = dist.join(edges, dist.vertex==edges.from_v).select(edges.to_v.alias("vertex"), (dist.distance+1).alias("distance"))                  
    
        result = dist.join(candidates, on="vertex", how="full_outer")\
        .select("vertex", f.when(dist.distance.isNotNull(), dist.distance).otherwise(candidates.distance).alias("distance"))        
       
        new_v = result.where(result.distance==d+1).count()
        
        if  new_v> 0:
            d += 1
            dist = candidates
            print("d = ", d, "count = ", new_v)
        else:
            break                    

    return result.select(result.distance).where(result.vertex == v_to).collect()[0]["distance"]    
        

In [104]:
print shortest_path(12, 34, "/data/twitter/twitter_sample2.txt")

('d = ', 1, 'count = ', 4)
('d = ', 2, 'count = ', 4)
('d = ', 3, 'count = ', 14)
('d = ', 4, 'count = ', 23)
('d = ', 5, 'count = ', 52)
('d = ', 6, 'count = ', 62)
('d = ', 7, 'count = ', 56)
('d = ', 8, 'count = ', 38)
8


In [8]:
!head /data/twitter/twitter_sample2.txt

12	18
12	41
12	57
12	62
12	235
12	278
12	291
12	338
12	456
12	614
