In [None]:
import os

# Location of the GraphFrames jar. Set the GRAPHFRAMES_JAR environment variable
# to point at a different install; the default is the Homebrew Apache Spark
# 3.1.2 path this notebook was originally written against.
graphframes_jar = os.environ.get(
    "GRAPHFRAMES_JAR",
    "/opt/homebrew/Cellar/apache-spark/3.1.2/libexec/jars/graphframes-0.8.0-spark3.0-s_2.12.jar",
)

# Register the jar with the SparkContext as a Python dependency
# (presumably so the `graphframes` import in the next cell resolves).
sc.addPyFile(graphframes_jar)

In [None]:
from graphframes import *
from pyspark.sql.functions import *

In [None]:
# Vertices DataFrame: one row per user, keyed by "id"
vertex_rows = [
    ("a", "Alice", 34),
    ("b", "Bob", 36),
    ("c", "Charlie", 37),
    ("d", "David", 29),
    ("e", "Esther", 32),
    ("f", "Fanny", 38),
    ("g", "Gabby", 60),
]
v = spark.createDataFrame(vertex_rows, schema=["id", "name", "age"])

# Edges DataFrame: one row per directed relationship from "src" to "dst"
edge_rows = [
    ("a", "b", "friend"),
    ("b", "c", "follow"),
    ("c", "b", "follow"),
    ("f", "c", "follow"),
    ("e", "f", "follow"),
    ("e", "d", "friend"),
    ("d", "a", "friend"),
    ("a", "e", "friend"),
    ("g", "e", "follow"),
]
e = spark.createDataFrame(edge_rows, schema=["src", "dst", "relationship"])

# Combine both DataFrames into a GraphFrame and display its two parts
g = GraphFrame(v, e)

g.vertices.show()
g.edges.show()

In [None]:
# g.vertices and g.edges are just DataFrames
# You can use any DataFrame API on them

# Example: show only the edges whose source vertex is 'a'.
g.edges.filter("src = 'a'").show()

In [None]:
# Count the edges whose source vertex is 'a'.
g.edges.filter("src = 'a'").count()

In [None]:
# Count the number of followers of c.
# This queries the edge DataFrame.
# A follower of c is an edge of type 'follow' whose destination is 'c'.
print(g.edges.filter("relationship = 'follow' and dst = 'c'").count())

In [None]:
# A GraphFrame has additional attributes

# outDegrees: the number of outgoing edges per vertex id.
g.outDegrees.show()

In [None]:
# inDegrees: the number of incoming edges per vertex id.
g.inDegrees.show()

In [None]:
# Print the query plan Spark uses to compute g.inDegrees.
g.inDegrees.explain()

In [None]:
# Re-derive the in-degrees by hand: count the edges arriving at each
# destination, then rename the columns to id / inDegree.
myInDegrees = (
    g.edges
     .groupBy('dst')
     .count()
     .withColumnRenamed('dst', 'id')
     .withColumnRenamed('count', 'inDegree')
)
myInDegrees.show()

In [None]:
# Print the query plan of the hand-written version, for comparison with g.inDegrees.explain().
myInDegrees.explain()

In [None]:
# Print the storage level of g.inDegrees (this cell runs before cache() is called on it).
print(g.inDegrees.storageLevel)

In [None]:
# Request caching of the g.inDegrees DataFrame.
g.inDegrees.cache()

In [None]:
# Print the storage level again, now that cache() has been called on g.inDegrees.
print(g.inDegrees.storageLevel)

In [None]:
# Print the storage level of the vertex DataFrame (this cell runs before g.cache()).
print(g.vertices.storageLevel)

In [None]:
# Cache the GraphFrame itself; the next cell inspects its vertices and edges.
g.cache()

In [None]:
# Print the storage levels of the vertex and edge DataFrames now that g.cache() has been called.
print(g.vertices.storageLevel)
print(g.edges.storageLevel)

In [None]:
# A triplet view of the graph
# One row per edge, combining the source vertex, the edge and the destination vertex.

g.triplets.show()

In [None]:
# Print the query plan behind the triplet view.
g.triplets.explain()

### Motif Finding

In [None]:
# Search for pairs of vertices with edges in both directions between them.
# Requiring a.id < b.id reports each pair once instead of as both (a, b) and (b, a).
mutual = g.find("(a)-[]->(b); (b)-[]->(a)")
motifs = mutual.filter('a.id < b.id')
motifs.show()

In [None]:
# Find triangles: directed 3-cycles a -> b -> c -> a.
# Every cycle matches the motif once per rotation; keeping only the rotation
# in which a has the smallest id lists each cycle a single time.

triangles = (
    g.find("(a)-[]->(b); (b)-[]->(c); (c)-[]->(a)")
     .filter("a.id < b.id AND a.id < c.id")
)
triangles.show()

In [None]:
# Print the query plan of the triangle query.
triangles.explain()

In [None]:
# Negation
# Edges a -> b for which no reverse edge b -> a exists.
oneway = g.find("(a)-[]->(b); !(b)-[]->(a)")
oneway.show()

In [None]:
# Find vertices without incoming edges:
# "()" is an unnamed vertex and "!" negates the edge pattern.
g.find("!()-[]->(a)").show()

In [None]:
# More meaningful queries can be expressed by applying filters.
# Question: where is this filter applied?
# The query: pairs connected in both directions, kept only when b is older than 36.

g.find("(a)-[e]->(b); (b)-[]->(a)").filter("b.age > 36").show()

In [None]:
# Print the query plan of the filtered motif query (here without naming the edge).
g.find("(a)-[]->(b); (b)-[]->(a)").filter("b.age > 36").explain()

In [None]:
# Find chains of 4 vertices (a -> b -> c -> d) such that at least 2 of the
# 3 edges are "friend" relationships.
# The where clause discards chains in which a equals c or d, or b equals d.
chain4 = (
    g.find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(d)")
     .where('a!=d AND a!=c AND b!=d')
)


def friendTo1(e):
    """Return a column that is 1 when the edge's relationship is 'friend', else 0."""
    # when/otherwise is the DataFrame counterpart of CASE WHEN in SQL
    return when(e['relationship'] == 'friend', 1).otherwise(0)


# One 0/1 flag column per edge of the chain
friend_flags = [
    friendTo1(chain4[edge_name]).alias(flag_name)
    for edge_name, flag_name in (('e1', 'f1'), ('e2', 'f2'), ('e3', 'f3'))
]

(
    chain4.select('*', *friend_flags)
          .where('f1 + f2 + f3 >= 2')
          .select('a', 'b', 'c', 'd')
          .show()
)

### Subgraphs

In [None]:
# Subgraph of users older than 30 connected by "friend" relationships.
# Users left without any relationship after filtering are dropped.

g1 = (
    g.filterVertices("age > 30")
     .filterEdges("relationship = 'friend'")
     .dropIsolatedVertices()
)

g1.vertices.show()
g1.edges.show()

In [None]:
# Subgraph made of "follow" edges e that point from a younger user a
# to an older user b.

paths = (
    g.find("(a)-[e]->(b)")
     .filter("e.relationship = 'follow'")
     .filter("a.age < b.age")
)
paths.show()

# paths also carries the two vertex structs; keep only the edge columns.
e2 = paths.select("e.*")
e2.show()

# Build the subgraph from the original vertices and the selected edges,
# then drop the vertices that no remaining edge touches.
g2 = GraphFrame(g.vertices, e2).dropIsolatedVertices()

g2.vertices.show()
g2.edges.show()

### BFS

In [None]:
# Layer-by-layer BFS; the starting vertex is 'a'.
# layers[k] holds the ids first reached after k hops, visited holds every id seen so far.
start = g.vertices.select('id').where("id = 'a'")
layers = [start]
visited = start

# Keep expanding until the newest layer comes back empty.
while layers[-1].count() > 0:
    frontier = layers[-1]
    # Follow every outgoing edge of the current layer to its one-hop neighbors
    hops = frontier.join(g.edges, frontier['id'] == g.edges['src'])
    # Keep the destination renamed to 'id', and remove visited vertices and duplicates
    next_layer = (
        hops.select(hops['dst'].alias('id'))
            .subtract(visited)
            .distinct()
            .cache()
    )
    layers.append(next_layer)
    visited = visited.union(next_layer).cache()

In [None]:
# Layer 0: the start vertex 'a'.
layers[0].show()

In [None]:
# Layer 1: vertices first reached one hop from the start vertex.
layers[1].show()

In [None]:
# Layer 2: vertices first reached two hops from the start vertex.
layers[2].show()

In [None]:
# Layer 3: vertices first reached three hops from the start vertex, if any.
layers[3].show()

In [8]:
# GraphFrames provides its own BFS: it searches from the vertices matching
# fromExpr to the vertices matching toExpr.

paths = g.bfs(fromExpr="id = 'a'", toExpr="age > 36")
paths.show()

+--------------+--------------+---------------+--------------+----------------+
|          from|            e0|             v1|            e1|              to|
+--------------+--------------+---------------+--------------+----------------+
|{a, Alice, 34}|{a, e, friend}|{e, Esther, 32}|{e, f, follow}|  {f, Fanny, 38}|
|{a, Alice, 34}|{a, b, friend}|   {b, Bob, 36}|{b, c, follow}|{c, Charlie, 37}|
+--------------+--------------+---------------+--------------+----------------+



### List Ranking

In [None]:
# -1 denotes end of list
data = [(0, 5), (1, 0), (3, 4), (4, 6), (5, -1), (6,1)]
e = spark.createDataFrame(data, ['src', 'dst'])
v = e.select(col('src').alias('id'), when(e.dst == -1, 0).otherwise(1).alias('d'))
v1 = spark.createDataFrame([(-1, 0)], ['id', 'd'])
v = v.union(v1)
v.show()
e.show()

In [None]:
while e.filter('dst != -1').count() > 0:
    g = GraphFrame(v, e)
    g.cache()
    v = g.triplets.select(col('src.id').alias('id'), 
                          (col('src.d') + col('dst.d')).alias('d')) \
         .union(v1)
    e = g.find('(a)-[]->(b); (b)-[]->(c)') \
         .select(col('a.id').alias('src'), col('c.id').alias('dst')) \
         .union(e.filter('dst = -1'))
    e.show()
v.show()

### Message passing via AggregateMessages

In [None]:
from pyspark.sql.functions import coalesce, col, lit, sum, when, min, max
from graphframes.lib import AggregateMessages as AM

# AggregateMessages exposes four columns: src, dst, edge, msg
# Every edge sends its source's age to its destination, and each vertex sums
# the messages it receives, i.e. the ages of the users with an edge into it.
# (Passing sendToSrc=AM.dst['age'] as well would also send ages back along each edge.)
summed_ages = sum(AM.msg).alias("summedAges")
agg = g.aggregateMessages(summed_ages, sendToDst=AM.src['age'])
agg.show()

### The Pregel Model for Graph Computation

In [None]:
# Pagerank in the Pregel model 

from pyspark.sql.functions import coalesce, col, lit, sum, when, min
from graphframes.lib import Pregel

# Pregel needs a checkpoint directory to be configured before it runs
sc.setCheckpointDir("checkpoint")

'''
Use builder pattern to describe the operations.
Call run() to start a run. It returns a DataFrame of vertices from the last iteration.

When a run starts, it expands the vertices DataFrame using column expressions 
defined by withVertexColumn(). Those additional vertex properties can be 
changed during Pregel iterations. In each Pregel iteration, there are three 
phases:

* Given each edge triplet, generate messages and specify target vertices to 
  send, described by sendMsgToDst() and sendMsgToSrc().
* Aggregate messages by target vertex IDs, described by aggMsgs().
* Update additional vertex properties based on aggregated messages and states 
  from previous iteration, described by withVertexColumn().
'''
# Rebuild the graph with the out-degree table as its vertices, so that every
# vertex carries the outDegree column the message expression divides by.
v = g.outDegrees
g = GraphFrame(v,e)

# Each vertex sends an equal share of its rank along every outgoing edge.
rank_share = Pregel.src("rank") / Pregel.src("outDegree")
# A vertex that received no message has a null sum; treat it as 0.0.
# pyspark.sql.functions.coalesce(*cols) returns the first column that is not null.
# Not to be confused with DataFrame.coalesce(numPartitions).
received = coalesce(Pregel.msg(), lit(0.0))

ranks = (
    g.pregel
     .setMaxIter(5)
     .sendMsgToDst(rank_share)
     .aggMsgs(sum(Pregel.msg()))
     .withVertexColumn("rank", lit(1.0), received * lit(0.85) + lit(0.15))
     .run()
)
ranks.show()


In [None]:
# BFS in the Pregel model
# d is the best known hop count from vertex 'a' (99999 stands for "not reached");
# active flags the vertices whose d improved in the latest iteration.

g = GraphFrame(v,e)

is_source = v['id'] == 'a'
# Null when no message arrived, which sends both updates to their otherwise() branch.
is_closer = Pregel.msg() < col('d')

dist = (
    g.pregel
     # Without an otherwise(), inactive vertices produce a null message.
     .sendMsgToDst(when(Pregel.src('active'), Pregel.src('d') + 1))
     .aggMsgs(min(Pregel.msg()))
     .withVertexColumn('d',
                       when(is_source, 0).otherwise(99999),
                       when(is_closer, Pregel.msg()).otherwise(col('d')))
     .withVertexColumn('active',
                       when(is_source, True).otherwise(False),
                       when(is_closer, True).otherwise(False))
     .run()
)
dist.show()
