In [1]:
from pyspark import SparkConf, SparkContext
sc = SparkContext(conf=SparkConf().setAppName("MyApp").setMaster("local[8]"))

In [2]:
def parse_edge(s):
    u, f = s.split("\t")
    return (int(u), int(f))

def step(i): 
    pv, pd, nv = i[0], i[1][0], i[1][1] 
    return (nv, pd + 1)

def complete(item): 
    v, od, nd = item[0], item[1][0], item[1][1]
    return (v, od if od is not None else nd)

def update_path(x):
    v, (old_path, (dist, new_v)) = x
    return new_v, old_path + (new_v,)

In [3]:
def shortest_path(v_from, v_to, dataset_path, numPartitions=10):
    edges = sc.textFile(dataset_path, numPartitions).map(parse_edge).cache()
    forward_edges = edges.map(lambda e: (e[1], e[0])).partitionBy(numPartitions).cache()
    
    d = 0
    distances = sc.parallelize([(v_from, d)]).partitionBy(numPartitions)
    paths = sc.parallelize([(v_from, (v_from,))])
    
    while True:
        candidates = distances.join(forward_edges, numPartitions).map(step)
        paths = paths.join(forward_edges).map(lambda x: (x[1][1], x[1][0] + (x[1][1],))).union(paths).distinct().cache()
        new_distances = distances.fullOuterJoin(candidates).map(complete).distinct().cache()
        count = new_distances.filter(lambda i: i[1] == d + 1).count()   
        if count > 0:
            d += 1     
            distances = new_distances
            # print "d = {}, count = {}".format(d, count)
        else:
            break
            
    result = paths.filter(lambda x: x[1][0] == v_from and x[1][-1] == v_to).collect()
    return ','.join(map(str, sorted(result, key=lambda x: len(x[1]))[0][1]))

In [5]:
%%time
shortest_path(12, 34, "/data/twitter/twitter_sample.txt")

CPU times: user 3.58 s, sys: 1.83 s, total: 5.41 s
Wall time: 11min 23s


'12,422,53,52,107,20,23,274,34'