### Initialize a Spark context

In [1]:
import pyspark
from pyspark import SparkConf, SparkContext
# getOrCreate() returns the context already running in this kernel (if any), so
# re-executing this cell does not fail with
# "Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(conf=SparkConf().setAppName('MyApp').setMaster('local'))

In [2]:
def parse_edge(s):
    """Turn one tab-separated text line into a ``(user, follower)`` pair of ints.

    Raises ValueError when the line does not hold exactly two tab-separated
    fields or when a field is not an integer literal.
    """
    user_field, follower_field = s.split('\t')
    user_id = int(user_field)
    follower_id = int(follower_field)
    return (user_id, follower_id)

In [3]:
# Load the edge list as an RDD of raw text lines.
edges = sc.textFile('/data/twitter/twitter_sample_small.txt')

In [4]:
# Parse every line into a (user, follower) pair of ints and cache the result.
# NOTE(review): `edges` is rebound from raw lines to parsed pairs, so running
# this cell twice in a row would pass tuples to parse_edge and fail; reload
# `edges` from the text file before re-running it.
edges = edges.map(parse_edge).cache()

In [5]:
# Peek at the first two parsed (user, follower) pairs.
edges.take(2)

[(12, 2241), (12, 13349)]

In [6]:
n = 4  # number of partitions used for the key-partitioned RDDs
# Swap each (user, follower) pair into (follower, user) so the follower becomes
# the join key, hash-partition by that key, and keep the result cached.
forward_edges = (
    edges
    .map(lambda edge: (edge[1], edge[0]))
    .partitionBy(n)
    .persist()
)

In [7]:
# Peek at two of the re-keyed (follower, user) pairs.
forward_edges.take(2)

[(755452, 12), (794748, 12)]

### Focus on one example vertex for demonstration

In [8]:
x = 12  # start vertex
# Distances known so far: only the start vertex, at distance 0.
distances = sc.parallelize([(x, 0)])
# Joining on the vertex key pairs the known distance with each adjacent vertex;
# rows have the shape (vertex, (distance, next_vertex)).
distances.join(forward_edges).take(2)

[(12, (0, 126)), (12, (0, 380))]

### Create a set of candidates — (key, value) pairs where the key is a vertex adjacent to x and the value is its distance from x

In [9]:
def step(item):
    """Advance one edge along a joined row.

    ``item`` has the shape ``(vertex, (distance, next_vertex))``; the result is
    ``(next_vertex, distance + 1)``.
    """
    distance_so_far = item[1][0]
    neighbour = item[1][1]
    return (neighbour, distance_so_far + 1)

In [10]:
# Advance one edge from every vertex in `distances`, giving
# (next_vertex, distance + 1) rows.
candidates = distances.join(forward_edges).map(step)
candidates.take(2)

[(126, 1), (380, 1)]

### Create a full outer join in which every key is a vertex and the value is the pair (old distance, new distance), with `None` where one side is missing

In [11]:
# Rows are (vertex, (old_distance, new_distance)); the side that lacks the key
# contributes None.
distances.fullOuterJoin(candidates).take(2)

[(648, (None, 1)), (12, (0, None))]

Footnote: reference for [fullOuterJoin](http://blog.cheyo.net/175.html)

### Create the final set of (key, value) pairs, where each key is a vertex and the value is its distance (the old distance is kept when both exist)

In [12]:
def complete(item):
    """Collapse a full-outer-join row into a single ``(vertex, distance)`` pair.

    ``item`` has the shape ``(vertex, (old_distance, new_distance))``. The old
    distance is kept whenever it is present (including 0); the new one is used
    only when the old distance is None.
    """
    vertex = item[0]
    known_d, proposed_d = item[1][0], item[1][1]
    if known_d is None:
        return (vertex, proposed_d)
    return (vertex, known_d)

In [13]:
# Collapse each (old, new) pair to one distance, preferring the existing one.
new_distances = distances.fullOuterJoin(candidates).map(complete)

In [14]:
# Peek at two of the merged (vertex, distance) pairs.
new_distances.take(2)

[(648, 1), (12, 0)]

### create a loop

In [15]:
x = 12  # start vertex
d = 0  # largest distance reached so far
# Seed the distance table with the start vertex, hash-partitioned into n partitions.
distances = sc.parallelize([(x,d)]).partitionBy(n)

In [16]:
# Breadth-first expansion: every pass extends each known vertex by one edge and
# stops as soon as a pass discovers no vertex at distance d + 1.
while True:
    # A vertex with several already-known neighbours yields one candidate row
    # per incoming edge. Without collapsing them, the outer join below emits
    # duplicate rows for that vertex, which inflates `count` and multiplies the
    # rows produced by every later pass. Keep one (the smallest) per vertex.
    candidates = distances.join(forward_edges, n).map(step).reduceByKey(min, n)  # n is the number of partitions
    # persist() keeps the merged table cached, so count() below and the joins
    # of the next pass reuse it instead of recomputing the whole lineage.
    # complete() leaves the key untouched, so the partitioning can be preserved.
    new_distances = distances.fullOuterJoin(candidates, n).map(complete, True).persist()
    count = new_distances.filter(lambda i: i[1] == d+1).count()

    if count > 0:
        d += 1
        # The previous table is superseded; drop its cached partitions
        # (a no-op for the initial, never-persisted seed RDD).
        distances.unpersist()
        distances = new_distances
        print("d = ", d, "count = ", count)
    else:
        # Nothing new was reached: `distances` already holds the final result,
        # so the cache of this last, redundant table can be released.
        new_distances.unpersist()
        break

('d = ', 1, 'count = ', 4)
('d = ', 2, 'count = ', 1)
('d = ', 3, 'count = ', 7)
('d = ', 4, 'count = ', 2)
('d = ', 5, 'count = ', 3)
('d = ', 6, 'count = ', 3)
('d = ', 7, 'count = ', 4)
('d = ', 8, 'count = ', 2)


### reconstruct shortest path

In [17]:
x = 12  # start vertex
path = 12  # first (and so far only) vertex on the path

In [18]:
# Seed RDD: the start vertex mapped to a one-element tuple holding `path`,
# hash-partitioned into n partitions.
x_path = sc.parallelize([(x, (path,))]).partitionBy(n)

In [19]:
# Show the seed: the start vertex mapped to its one-element path.
x_path.collect()

[(12, (12,))]

### Create a new set of candidates — (key, path) pairs where the key is the next vertex and the path is the tuple of vertices leading to it

In [20]:
def step_on_path(item):
    """Extend a joined row's path by one vertex.

    ``item`` has the shape ``(vertex, (path, next_vertex))``; the result is
    ``(next_vertex, path)`` with ``next_vertex`` appended to the path.
    """
    trail = item[1][0]
    successor = item[1][1]
    trail += (successor,)
    return (successor, trail)

In [21]:
# Extend every known path by one edge: (next_vertex, path + (next_vertex,)).
candidates = x_path.join(forward_edges).map(step_on_path)

In [22]:
# Peek at two of the extended paths.
candidates.take(2)

[(126, (12, 126)), (380, (12, 380))]

### Create the final set of (key, path) pairs, where each key is a vertex and the value is the path of vertices leading to it

In [23]:
def complete_path(item):
    """Collapse a full-outer-join row into a single ``(vertex, path)`` pair.

    ``item`` has the shape ``(vertex, (old_path, new_path))``. The old path is
    kept whenever it is present; the new one is used only when the old path is
    None.
    """
    vertex = item[0]
    known_path, proposed_path = item[1][0], item[1][1]
    if known_path is None:
        return (vertex, proposed_path)
    return (vertex, known_path)

In [24]:
# Grow the known paths one edge per pass until the ending vertex shows up.
target = 34  # 34 is the ending vertex
# Number of distinct vertices reached so far; used to detect a pass that makes
# no progress, which means the ending vertex cannot be reached at all.
reached = x_path.keys().distinct().count()
while True:
    candidates = x_path.join(forward_edges, n).map(step_on_path)
    # complete_path() leaves the key untouched, so partitioning is preserved.
    new_x_path = x_path.fullOuterJoin(candidates, n).map(complete_path, True).persist()
    count = new_x_path.filter(lambda i: i[0] == target).count()
    if count > 0:
        break
    now_reached = new_x_path.keys().distinct().count()
    if now_reached == reached:
        # No new vertex was added, so further passes cannot add one either.
        # Stop here instead of looping forever.
        print('vertex {} is not reachable from the start vertex'.format(target))
        break
    reached = now_reached
    # The previous table is superseded; drop its cached partitions
    # (a no-op for the initial, never-persisted seed RDD).
    x_path.unpersist()
    x_path = new_x_path
    print('continue on new path')

continue on new path
continue on new path
continue on new path
continue on new path
continue on new path
continue on new path
continue on new path


In [26]:
# Keep only the rows keyed by vertex 34, the ending vertex, and report how many there are.
final_set_path = new_x_path.filter(lambda i: i[0] == 34)
print("number of pathes from vertex 12 to vertex 34:", final_set_path.count())

('number of pathes from vertex 12 to vertex 34:', 1)


In [27]:
# Take the path tuple of the first matching row and turn it into a list.
final = list(final_set_path.collect()[0][1])

In [28]:
# Print the vertices along the path as a comma-separated string.
print(','.join(map(str, final)))

12,422,53,52,107,20,23,274,34
