<a href="https://colab.research.google.com/github/Brent-Morrison/PySpark_examples/blob/main/connected_components_methods.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction

This notebook documents several methods for identifying connected components from a list of paired records.  The problem is best understood with an example.  
  
Consider a matching process that identifies related records, as represented in the table below.  Record 1 is matched to record 2, and record 2 is matched to record 3, so these three records form a single connected component.  Records 4 and 5 form a second component.
<br>

| record 1 	| record 2 	| match key 	|
|----------	|----------	|-----------	|
| 1        	| 2        	| 1_2_3     	|
| 2        	| 3        	| 1_2_3     	|
| 4        	| 5        	| 4_5       	|
<br>

Assigning a group label (```match key``` above) is simple when every record appears in at most one pair across the full data set.

It becomes harder when more than two records are linked, because records can belong to the same group through a chain of pairs without being paired directly (records 1 and 3 above).  

In graph theory this is called finding connected components.
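
As a minimal illustration of the idea, the sketch below labels the three pairs from the table above using a union-find (disjoint set) structure. This is a swapped-in technique for illustration only, not one of the methods compared in this notebook, and the names `label_pairs`, `find` and `union` are hypothetical.

```python
def label_pairs(pairs):
    """Return {record: match_key} for a list of (record, record) pairs."""
    parent = {}

    def find(x):
        # Walk up to the root of x's tree, shortening the path as we go
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        # Attach the larger root under the smaller, so the root is the minimum id
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[max(root_a, root_b)] = min(root_a, root_b)

    for a, b in pairs:
        union(a, b)

    # Gather the members of each component under its root
    members = {}
    for record in parent:
        members.setdefault(find(record), []).append(record)

    # The match key is the sorted member ids joined by underscores
    return {
        record: "_".join(str(m) for m in sorted(group))
        for group in members.values()
        for record in group
    }


intro_pairs = [(1, 2), (2, 3), (4, 5)]
intro_labels = label_pairs(intro_pairs)
print(intro_labels)
# records 1, 2 and 3 share the key '1_2_3'; records 4 and 5 share '4_5'
```

Record 3 receives the same key as record 1 even though the two never appear in the same pair, which is exactly the transitive behaviour a connected components search provides.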

Below, this is implemented in three ways: in Spark via the GraphFrames module, in plain Python, and in Spark by iterating over dataframes.

The mock data used below results in the following connected components.  
  
_```(1, 2, 3, 30)```_  
_```(4, 5)```_  
_```(6, 7, 8, 9, 23)```_  
_```(10, 11, 12, 15, 17, 18, 20, 21)```_  
<br>
_Note - I haven't re-shaped each of the outputs below to a standard format in the interests of time.  This is relatively easy with some standard data munging._
<br>
# Set up

Download and install Spark

In [None]:
!apt-get install openjdk-11-jdk-headless -qq > /dev/null
!wget -q https://archive.apache.org/dist/spark/spark-3.0.2/spark-3.0.2-bin-hadoop2.7.tgz
!tar xf spark-3.0.2-bin-hadoop2.7.tgz
!rm -rf spark-3.0.2-bin-hadoop2.7.tgz*
!pip -q install findspark pyspark graphframes



Set the environment variables so that Colab can find Spark

In [None]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.0.2-bin-hadoop2.7"
os.environ["HADOOP_HOME"] = os.environ["SPARK_HOME"]
os.environ["PYSPARK_SUBMIT_ARGS"] = "--packages graphframes:graphframes:0.8.1-spark3.0-s_2.12 pyspark-shell"

Add PySpark to sys.path

PySpark isn't on sys.path by default, but that doesn't mean it can't be used as a regular library. You can address this by either symlinking pyspark into your site-packages, or adding pyspark to sys.path at runtime. [findspark](https://github.com/minrk/findspark) does the latter.
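
A rough sketch of what "adding pyspark to sys.path at runtime" amounts to is shown below. It assumes the usual Spark distribution layout, with the Python sources under `python/` and a bundled `py4j-*.zip` under `python/lib/`. The helper name `add_pyspark_to_path` is hypothetical, and this is not findspark's actual code.

```python
import glob
import os
import sys


def add_pyspark_to_path(spark_home):
    """Prepend Spark's bundled Python packages to sys.path (illustrative only)."""
    spark_python = os.path.join(spark_home, "python")
    # Assumption: the distribution ships py4j as a zip under python/lib
    py4j_zips = glob.glob(os.path.join(spark_python, "lib", "py4j-*.zip"))
    new_paths = [spark_python] + py4j_zips[:1]
    # Prepend only the entries that are not already on sys.path
    sys.path[:0] = [p for p in new_paths if p not in sys.path]
    return new_paths
```

In this notebook the equivalent call would be `add_pyspark_to_path(os.environ["SPARK_HOME"])`, although `findspark.init()` is what is actually used.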

In [None]:
import findspark
findspark.init()

Create the Spark session

In [None]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()  #.appName('test')
sc = spark.sparkContext

Before using PySpark, import the required classes from the PySpark SQL module and the GraphFrames module.

In [None]:
from pyspark.sql import functions as F
from pyspark.sql import Window as W
from pyspark.sql import Column as C
from pyspark.sql import GroupedData as G
from pyspark.sql import DataFrame
from pyspark.sql.types import *
from graphframes import *

### Data

Load mock data to a PySpark data frame

In [None]:
# In a list
raw_matches_list = [
    [1,2]
    ,[1,30]
    ,[2,3]
    ,[2,3]
    ,[4,5]
    ,[6,7]
    ,[7,8]
    ,[8,9]
    ,[10,12]
    ,[10,15]
    ,[10,17]
    ,[11,15]
    ,[15,20]
    ,[17,18]
    ,[20,21]
    ,[23,8]
  ]

In [None]:
# Data frame
raw_matches = spark.createDataFrame(
  raw_matches_list,
  ['ecid_1', 'ecid_2', ] 
)
raw_matches.show(20,truncate=False)

+------+------+
|ecid_1|ecid_2|
+------+------+
|1     |2     |
|1     |30    |
|2     |3     |
|2     |3     |
|4     |5     |
|6     |7     |
|7     |8     |
|8     |9     |
|10    |12    |
|10    |15    |
|10    |17    |
|11    |15    |
|15    |20    |
|17    |18    |
|20    |21    |
|23    |8     |
+------+------+



Create vertices and edges data frames using the data above

In [None]:
edges = raw_matches.withColumnRenamed('ecid_1','src').withColumnRenamed('ecid_2','dst').withColumn('strength', F.lit(1))

edges.show(truncate=False)

+---+---+--------+
|src|dst|strength|
+---+---+--------+
|1  |2  |1       |
|1  |30 |1       |
|2  |3  |1       |
|2  |3  |1       |
|4  |5  |1       |
|6  |7  |1       |
|7  |8  |1       |
|8  |9  |1       |
|10 |12 |1       |
|10 |15 |1       |
|10 |17 |1       |
|11 |15 |1       |
|15 |20 |1       |
|17 |18 |1       |
|20 |21 |1       |
|23 |8  |1       |
+---+---+--------+



In [None]:
vertices = (raw_matches
  .select('ecid_1')
  .unionAll(raw_matches.select('ecid_2'))
  .distinct()
  .withColumnRenamed('ecid_1','id')
  .withColumn('block_key', F.lit('block_key_01'))
)

vertices.show(truncate=False)    

+---+------------+
|id |block_key   |
+---+------------+
|7  |block_key_01|
|6  |block_key_01|
|9  |block_key_01|
|17 |block_key_01|
|5  |block_key_01|
|1  |block_key_01|
|10 |block_key_01|
|3  |block_key_01|
|12 |block_key_01|
|8  |block_key_01|
|11 |block_key_01|
|2  |block_key_01|
|4  |block_key_01|
|18 |block_key_01|
|21 |block_key_01|
|15 |block_key_01|
|30 |block_key_01|
|23 |block_key_01|
|20 |block_key_01|
+---+------------+
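
As a quick cross-check that does not need Spark, the same vertex and edge sets can be derived from the raw list in plain Python. The list is repeated here so the snippet is self-contained; `mock_pairs`, `vertex_ids` and `distinct_edges` are illustrative names.

```python
mock_pairs = [
    [1, 2], [1, 30], [2, 3], [2, 3], [4, 5], [6, 7], [7, 8], [8, 9],
    [10, 12], [10, 15], [10, 17], [11, 15], [15, 20], [17, 18], [20, 21], [23, 8],
]

# Vertices: every id that appears on either side of a pair, de-duplicated
vertex_ids = sorted({ecid for pair in mock_pairs for ecid in pair})

# Edges: the pairs themselves; [2, 3] appears twice in the mock data
distinct_edges = {tuple(pair) for pair in mock_pairs}

print(len(vertex_ids), len(mock_pairs), len(distinct_edges))  # 19 16 15
```

The 19 ids agree with the 19 rows of the vertices data frame. The duplicate `[2, 3]` edge is kept in the edges data frame, which does not change the resulting components.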



# Connected components using GraphFrames

Create GraphFrame

In [None]:
g_ecid = GraphFrame(vertices, edges)

Call the connected components function.  Note that the default GraphFrames algorithm requires a checkpoint directory to be set first.

In [None]:
sc.setCheckpointDir("/tmp/graphframes_eg1_cc")
gcc = g_ecid.connectedComponents()

Result

In [None]:
gcc.orderBy('component', 'id').show(20, truncate=False)

+---+------------+---------+
|id |block_key   |component|
+---+------------+---------+
|1  |block_key_01|1        |
|2  |block_key_01|1        |
|3  |block_key_01|1        |
|30 |block_key_01|1        |
|4  |block_key_01|4        |
|5  |block_key_01|4        |
|6  |block_key_01|6        |
|7  |block_key_01|6        |
|8  |block_key_01|6        |
|9  |block_key_01|6        |
|23 |block_key_01|6        |
|10 |block_key_01|10       |
|11 |block_key_01|10       |
|12 |block_key_01|10       |
|15 |block_key_01|10       |
|17 |block_key_01|10       |
|18 |block_key_01|10       |
|20 |block_key_01|10       |
|21 |block_key_01|10       |
+---+------------+---------+
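
The GraphFrames output can be brought into the match key format from the introduction with a small amount of reshaping. The sketch below does this in plain Python on `(id, component)` tuples, assuming the result is small enough to collect to the driver (for example via `gcc.select('id', 'component').collect()`). The function name `component_rows_to_match_keys` is hypothetical, and only a subset of the rows above is used.

```python
from collections import defaultdict


def component_rows_to_match_keys(rows):
    """Map each id to an underscore-joined key of all ids in its component."""
    members = defaultdict(list)
    for record_id, component in rows:
        members[component].append(record_id)

    keys = {}
    for ids in members.values():
        key = "_".join(str(i) for i in sorted(ids))
        for record_id in ids:
            keys[record_id] = key
    return keys


# A subset of the (id, component) rows shown in the result above
sample_rows = [(1, 1), (2, 1), (3, 1), (30, 1), (4, 4), (5, 4)]
sample_keys = component_rows_to_match_keys(sample_rows)
print(sample_keys[30], sample_keys[5])  # 1_2_3_30 4_5
```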



# Connected components using Python

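The function below builds an adjacency list from the pairs, then uses a recursive depth-first search to collect every record reachable from each unvisited record into one group.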

In [None]:
from collections import defaultdict

def connected_components(pairs):
    # Build an adjacency list: each record maps to the records it is paired with
    nodes = defaultdict(list)
    for a, b in pairs:
        if b is not None:
            nodes[a].append(b)
            nodes[b].append(a)
        else:
            nodes[a]  # register an unmatched record with an empty neighbour list

    # Add all reachable neighbours to the same group
    visited = set()
    def _build_group(key, group):
        if key in visited:
            return
        visited.add(key)
        group.add(key)
        for neighbour in nodes[key]:
            _build_group(neighbour, group)

    # Start a new group from every record that has not been visited yet
    groups = []
    for key in nodes:
        if key in visited:
            continue
        groups.append(set())
        _build_group(key, groups[-1])

    return groups

Return connected components - this will be a list of sets

In [None]:
from collections import defaultdict
import pandas as pd

cc_set = connected_components(raw_matches_list)
cc_set

[{1, 2, 3, 30}, {4, 5}, {6, 7, 8, 9, 23}, {10, 11, 12, 15, 17, 18, 20, 21}]
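
One caveat with a recursive search is Python's recursion limit (1000 frames by default in CPython), which a long chain of matches could exceed. The sketch below is an iterative alternative driven by an explicit stack. It is a swapped-in variant rather than the function used in this notebook, the name `connected_components_iterative` is hypothetical, and it assumes every pair has two non-null members.

```python
from collections import defaultdict


def connected_components_iterative(pairs):
    """Return connected components as a list of sets, without recursion."""
    # Build an undirected adjacency list
    neighbours = defaultdict(set)
    for a, b in pairs:
        neighbours[a].add(b)
        neighbours[b].add(a)

    visited = set()
    groups = []
    for start in list(neighbours):
        if start in visited:
            continue
        # Depth-first traversal driven by an explicit stack
        group = set()
        stack = [start]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            group.add(node)
            stack.extend(neighbours[node] - visited)
        groups.append(group)
    return groups


# A chain of 5,000 pairs, longer than the default recursion limit
chain = [(i, i + 1) for i in range(5000)]
print(len(connected_components_iterative(chain)))  # 1
```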

Reshape to desired format

In [None]:
# Convert list of sets to list of lists 
cc_list = []
for s in cc_set:
  cc_list.append(list(s))


# Convert to pandas data frame
cc_list2pd = pd.DataFrame({'col1': cc_list})

                                  
cust_schema = StructType([
  StructField('ecid_array', ArrayType(LongType()), True)
])


# Convert to spark data frame
df3 = spark.createDataFrame(cc_list2pd,schema=cust_schema)


# Reshape to one row per ecid, labelled with the smallest ecid in the group
# (match_key_1) and the underscore-joined sorted members (match_key_2)
df4 = (df3
  .withColumn('ecid_1',F.explode(F.col('ecid_array')))
  .withColumn('match_key_2',F.sort_array(F.col('ecid_array')))
  .withColumn('match_key_2', F.concat_ws('_',F.col('match_key_2')))
  .withColumn('match_key_1', F.min('ecid_1').over(W.partitionBy('match_key_2')))
  .orderBy('match_key_2','ecid_1')
  .select('ecid_1','match_key_1','match_key_2')
)

Result

In [None]:
df4.orderBy('match_key_1','ecid_1').show(truncate=False)

+------+-----------+-----------------------+
|ecid_1|match_key_1|match_key_2            |
+------+-----------+-----------------------+
|1     |1          |1_2_3_30               |
|2     |1          |1_2_3_30               |
|3     |1          |1_2_3_30               |
|30    |1          |1_2_3_30               |
|4     |4          |4_5                    |
|5     |4          |4_5                    |
|6     |6          |6_7_8_9_23             |
|7     |6          |6_7_8_9_23             |
|8     |6          |6_7_8_9_23             |
|9     |6          |6_7_8_9_23             |
|23    |6          |6_7_8_9_23             |
|10    |10         |10_11_12_15_17_18_20_21|
|11    |10         |10_11_12_15_17_18_20_21|
|12    |10         |10_11_12_15_17_18_20_21|
|15    |10         |10_11_12_15_17_18_20_21|
|17    |10         |10_11_12_15_17_18_20_21|
|18    |10         |10_11_12_15_17_18_20_21|
|20    |10         |10_11_12_15_17_18_20_21|
|21    |10         |10_11_12_15_17_18_20_21|
+------+-----------+-----------------------+

# Connected components using PySpark dataframes

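This method merges groups by looping over the distinct ecids and rebuilding the dataframe on each pass.  Every pass extends the query plan of the previous one, so ideally the loop should checkpoint intermediate results to truncate the lineage.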

In [None]:
# Convert test data to required array format
t1 = (raw_matches
.withColumn('match_key1', F.array(F.col('ecid_1'),F.col('ecid_2')))
.select('match_key1')
)



# All ecids in a single sorted array to iterate over
all_ecid = (t1
  .withColumn('group', F.lit(1))
  .groupBy('group')
  .agg(F.array_distinct(F.flatten(F.collect_list('match_key1'))).alias('ecid_array'))
  .withColumn('ecid_array', F.array_sort('ecid_array'))
  .orderBy('ecid_array')
  .drop('group')
  .persist()
)



# Collect the sorted ecid array once; its length is the number of iterations
ecid_list = all_ecid.collect()[0][0]
ecid_count = len(ecid_list)


# Loop
t2 = t1
for i in range(0,ecid_count):
  print('Index: ',i)
  
  # Count how many times each ecid occurs across the current groups
  all_ecid_count = t2.withColumn('group_by',F.lit(1)).groupBy('group_by') \
  .agg(F.flatten(F.collect_list('match_key1')).alias('match_key1')) \
  .withColumn('all',F.explode('match_key1')) \
  .groupBy('all').count() \
  .orderBy('all')
  
  # Skip this ecid if it occurs only once, since there is nothing to merge
  ecid_no = all_ecid_count.filter(F.col('all') == ecid_list[i]).select('count')
  ecid_no = ecid_no.collect()[0][0]
  if ecid_no == 1:
    continue
  
  # Find every array containing the current ecid and merge them into a single array
  t3 = all_ecid.withColumn('ecid',F.col('ecid_array').getItem(i)).crossJoin(t2) \
  .withColumn('ecid', F.array('ecid')) \
  .withColumn('overlap',F.arrays_overlap('match_key1', 'ecid'))\
  .withColumn('group_by',F.when(F.col('overlap') == True, F.col('ecid')).otherwise(F.col('match_key1'))) \
  .groupBy('group_by') \
  .agg(F.flatten(F.collect_list('match_key1')).alias('match_key1')) \
  .withColumn('match_key1',F.array_distinct('match_key1')) \
  .withColumn('match_key1', F.array_sort('match_key1')) \
  .orderBy('match_key1') \
  .drop('group_by')
  t2 = t3

Index:  0
Index:  1
Index:  2
Index:  3
Index:  4
Index:  5
Index:  6
Index:  7
Index:  8
Index:  9
Index:  10
Index:  11
Index:  12
Index:  13
Index:  14
Index:  15
Index:  16
Index:  17
Index:  18


Result

In [None]:
t2.show(truncate=False)

+--------------------------------+
|match_key1                      |
+--------------------------------+
|[1, 2, 3, 30]                   |
|[4, 5]                          |
|[6, 7, 8, 9, 23]                |
|[10, 11, 12, 15, 17, 18, 20, 21]|
+--------------------------------+
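
The loop above can be hard to follow through the dataframe operations. The sketch below restates the same idea in plain Python on a subset of the mock data: visit each ecid in turn and merge every group that contains it. It is an illustrative restatement rather than the Spark code itself, and `merge_groups_by_ecid` is a hypothetical name.

```python
def merge_groups_by_ecid(pairs):
    """Merge pairs into connected groups, one ecid at a time."""
    # Start with one sorted group per pair, like the match_key1 arrays
    groups = [sorted(set(pair)) for pair in pairs]
    all_ecids = sorted({ecid for group in groups for ecid in group})

    for ecid in all_ecids:
        containing = [g for g in groups if ecid in g]
        # Nothing to merge if the ecid appears in a single group
        if len(containing) == 1:
            continue
        merged = sorted({member for g in containing for member in g})
        others = [g for g in groups if ecid not in g]
        groups = others + [merged]

    return sorted(groups)


subset = [[1, 2], [1, 30], [2, 3], [2, 3], [4, 5]]
print(merge_groups_by_ecid(subset))  # [[1, 2, 3, 30], [4, 5]]
```

Once an ecid has been visited it sits in exactly one group, and later merges only ever union whole groups, so a single pass over the ecids is enough for the groups to become disjoint.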



## References

[Checkpointing 1](https://enigma.com/blog/post/things-i-wish-id-known-about-spark)  
[Checkpointing 2](https://dzone.com/articles/what-are-spark-checkpoints-on-dataframes)  
[Checkpointing 3](https://jaceklaskowski.gitbooks.io/mastering-spark-sql/content/spark-sql-checkpointing.html)  
[Dataframe implementation (Scala)](https://blogs.oracle.com/ai-and-datascience/post/graph-computations-with-apache-spark)  
[Dataframe implementation (PySpark)](https://towardsdatascience.com/connected-components-at-scale-in-pyspark-4a1c6423b9ed)  


In [None]:
!ls

sample_data  spark-3.0.2-bin-hadoop2.7
