In [1]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.types import *
from pyspark.sql.functions import explode, collect_list, size, col, row_number, sort_array, udf, count

In [2]:
# Create (or reuse) a Hive-enabled Spark session with a "local" master.
sparkSession = (
    SparkSession.builder
    .enableHiveSupport()
    .master("local")
    .getOrCreate()
)

# Parquet input; the pipeline below reads its `user` and `friends` columns.
graphPath = "/data/graphDFSample"

# Invert the graph: explode `friends` into one row per element, then, for
# every friend, collect the users that listed it together with the length
# of that collected list.
reversedGraph = (
    sparkSession.read.parquet(graphPath)
    .withColumn("friend", explode("friends"))
    .groupBy("friend")
    .agg(collect_list("user").alias("users"))
    .withColumn("users_size", size("users"))
)

#reversedGraph.show(3)

In [3]:
# Sort each friend's user list (sort_array defaults to ascending order) and
# keep only friends listed by more than one user, since a single user cannot
# form a pair.
#
# `users_size` is an integer column, so it is compared with the integer 1
# rather than the string '1', which relied on an implicit cast.
#
# NOTE(review): this cell rebinds `reversedGraph` and drops the `users`
# column, so it cannot be re-run without first re-running the cell that
# builds `reversedGraph` from the parquet input.
reversedGraph = (
    reversedGraph
    .select(
        col("friend"),
        sort_array(col("users")).alias("users_sorted"),
        col("users_size"),
    )
    .where(col("users_size") > 1)
)
#reversedGraph.show(3)

In [4]:
def serializer(input_array):
    """Return every pair of elements of ``input_array`` as 2-tuples.

    Pairs are built from positions ``i < j`` and emitted in position order:
    ``[a, b, c]`` gives ``[(a, b), (a, c), (b, c)]``. The input order is
    preserved inside each pair and no de-duplication is performed.

    Args:
        input_array: A sequence of values, or ``None``.

    Returns:
        A list of 2-tuples. Empty when the input is ``None`` or holds fewer
        than two elements.
    """
    # itertools is not part of the notebook's import cell, so it is imported
    # here to keep this definition self-contained.
    from itertools import combinations

    # A null array column reaches a Python UDF as None; treat it as "no pairs"
    # instead of failing on len(None).
    if input_array is None:
        return []

    return list(combinations(input_array, 2))

In [None]:
# Wrap `serializer` as a Spark UDF. The declared return type is an array of
# structs with two nullable integer fields named "1" and "2".
# NOTE(review): IntegerType assumes the values fit in a 32-bit integer;
# confirm against the schema of the input data.
serializer_udf = udf(
    serializer,
    ArrayType(
        StructType([
            StructField("1", IntegerType(), True),
            StructField("2", IntegerType(), True),
        ])
    ),
)

In [None]:
# Turn each sorted user list into the list of pairs produced by
# serializer_udf (stored as `users`); `users_size` is passed through as is.
reversedGraph = reversedGraph.select(
    serializer_udf(col("users_sorted")).alias("users"),
    col("users_size"),
)
#reversedGraph.show(3)

In [None]:
# One output row per element of the `users` array. The exploded column is
# named "col", which is explode's default output name for array input; it is
# spelled out here because the next step selects "col.*".
reversedGraph_2 = reversedGraph.select(
    explode(col("users")).alias("col"),
    col("users_size"),
)
#reversedGraph_2.show(3)

In [None]:
# Expand the struct held in column "col" into one top-level column per field.
reversedGraph_3 = reversedGraph_2.select(
    col("col.*")
)
#reversedGraph_3.show(10)

In [None]:
# Count the rows for every distinct ("1", "2") combination; the result is
# stored in a column named "count".
reversedGraph_4 = (
    reversedGraph_3
    .groupBy(col("1"), col("2"))
    .count()
)

#reversedGraph_4.show(5)

In [None]:
# Window over the whole DataFrame (no partitioning columns), ordered by
# descending "count"; used below to number the rows.
window = Window.orderBy(col("count").desc())

# Number the rows, keep those whose row number is below 50, and collect
# (count, "1", "2") to the driver, sorted by all three columns descending.
# NOTE(review): `row_number < 50` keeps row numbers 1..49, i.e. 49 rows,
# although the result is named `top50`. Confirm the intended size before
# changing the bound.
top50 = (
    reversedGraph_4
    .withColumn("row_number", row_number().over(window))
    .where(col("row_number") < 50)
    .select("count", "1", "2")
    .orderBy("count", "1", "2", ascending=False)
    .collect()
)

In [None]:
# Print each collected row as three space-separated values.
for row in top50:
    print('%s %s %s' % tuple(row))