In [23]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder.appName("repartition").master("local[*]").getOrCreate()


# Set spark conf to skip broadcast join
spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold",-1)

In [24]:
df_employee = spark.read.option("header",True).option("inferSchema",True).csv("employee_records.txt")

In [26]:
df_employee.show(5)

+----------+---------+--------------------+-------------------+--------------------+------------------+------+-------------+
|first_name|last_name|           job_title|                dob|               email|             phone|salary|department_id|
+----------+---------+--------------------+-------------------+--------------------+------------------+------+-------------+
|   Richard| Morrison|Public relations ...|1973-05-05 00:00:00|melissagarcia@exa...|     (699)525-4827|512653|            8|
|     Bobby| Mccarthy|   Barrister's clerk|1974-04-25 00:00:00|   llara@example.net|(750)846-1602x7458|999836|            7|
|    Dennis|   Norman|Land/geomatics su...|1990-06-24 00:00:00| jturner@example.net|  873.820.0518x825|131900|           10|
|      John|   Monroe|        Retail buyer|1968-06-16 00:00:00|  erik33@example.net|  820-813-0557x624|485506|            1|
|  Michelle|  Elliott|      Air cabin crew|1975-03-31 00:00:00|tiffanyjohnston@e...|     (705)900-5337|604738|            8|


In [27]:
# Current number of partition in the dataframe
df_employee.rdd.getNumPartitions()

24

In [28]:
# Distribution of records in partitions
df_employee.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").count().show(24)

+------------+-----+
|partition_id|count|
+------------+-----+
|          12|43092|
|          22|43025|
|           1|43062|
|          13|43031|
|           6|43014|
|          16|43061|
|           3|43061|
|          20|43031|
|           5|43070|
|          19|43044|
|          15|43058|
|           9|43027|
|          17|43055|
|           4|43046|
|           8|43091|
|          23| 9828|
|           7|43067|
|          10|43065|
|          21|43022|
|          11|43069|
|          14|43033|
|           2|43054|
|           0|43075|
|          18|43019|
+------------+-----+



In [29]:
df_dep = spark.read.option("header",True).option("inferSchema",True).csv("department_data.txt")

In [47]:
# join which forces shuffle and created spark.sql.shuffle.partitions number of partitions
df_joined =  df_employee.join(df_dep,on="department_id",how="inner")

In [48]:
df_joined.show(5)

+-------------+-----------+---------+--------------------+-------------------+--------------------+--------------------+------+---------------+--------------------+------------+-----+-------------------+
|department_id| first_name|last_name|           job_title|                dob|               email|               phone|salary|department_name|         description|        city|state|            country|
+-------------+-----------+---------+--------------------+-------------------+--------------------+--------------------+------+---------------+--------------------+------------+-----+-------------------+
|            1|       John|   Monroe|        Retail buyer|1968-06-16 00:00:00|  erik33@example.net|    820-813-0557x624|485506|    Bryan-James|Optimized disinte...|Melissaburgh|   FM|Trinidad and Tobago|
|            1|    Rachael|Rodriguez|         Media buyer|1966-12-02 00:00:00|griffinmary@examp...| +1-791-344-7586x548|544732|    Bryan-James|Optimized disinte...|Melissaburgh|   FM|T

In [53]:
# default 200 partitions after shuffle
df_joined.rdd.getNumPartitions()

200

In [50]:
# Due to large num of partitions, uneven distrbution and lot of empty partitions
df_joined.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").count().show()

+------------+------+
|partition_id| count|
+------------+------+
|         103|100417|
|         122| 99780|
|          43| 99451|
|         107| 99805|
|          49| 99706|
|          51|100248|
|         102|100214|
|          66|100210|
|         174|100155|
|          89|100014|
+------------+------+



In [56]:
# repartition the data to achieve uniformity
df_rep = df_joined.repartition(20)
df_rep.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").count().show()

+------------+-----+
|partition_id|count|
+------------+-----+
|          12|50000|
|           1|49998|
|          13|50000|
|           6|50002|
|          16|50000|
|           3|49999|
|           5|50002|
|          19|49999|
|          15|50000|
|           9|50000|
|          17|50001|
|           4|50001|
|           8|50000|
|           7|50001|
|          10|50001|
|          11|50000|
|          14|50000|
|           2|49998|
|           0|49999|
|          18|49999|
+------------+-----+



In [67]:
# Lets try to perform repartition on a particular column
df_names = df_rep.repartition(10,"first_name")
df_names.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").count().show(30)

# we can see that it now the data is not uniformly distributed. So when we give a key to perform repartition, it can lead to data skewness

+------------+------+
|partition_id| count|
+------------+------+
|           1|155967|
|           6| 95853|
|           3| 90682|
|           5| 89830|
|           9|105484|
|           4| 96771|
|           8| 89842|
|           7| 76695|
|           2|100432|
|           0| 98444|
+------------+------+



In [73]:
df_coal = df_names.coalesce(5)
df_coal.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").count().show(5)

+------------+------+
|partition_id| count|
+------------+------+
|           1|191114|
|           3|172548|
|           4|195326|
|           2|186601|
|           0|254411|
+------------+------+



In [74]:
df_coal.explain()
# we can see that coalesce did not perform shuffling

== Physical Plan ==
Coalesce 5
+- Exchange hashpartitioning(first_name#343, 10), REPARTITION_BY_NUM, [id=#2998]
   +- *(5) Project [department_id#350, first_name#343, last_name#344, job_title#345, dob#346, email#347, phone#348, salary#349, department_name#494, description#495, city#496, state#497, country#498]
      +- *(5) SortMergeJoin [department_id#350], [department_id#493], Inner
         :- *(2) Sort [department_id#350 ASC NULLS FIRST], false, 0
         :  +- Exchange hashpartitioning(department_id#350, 200), ENSURE_REQUIREMENTS, [id=#2982]
         :     +- *(1) Filter isnotnull(department_id#350)
         :        +- FileScan csv [first_name#343,last_name#344,job_title#345,dob#346,email#347,phone#348,salary#349,department_id#350] Batched: false, DataFilters: [isnotnull(department_id#350)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/home/jupyter/data/employee_records.txt], PartitionFilters: [], PushedFilters: [IsNotNull(department_id)], ReadSchema: struct<first_nam

In [76]:
# lets try aggregation on data which is partitioned by a column
df_names_agg = df_names.groupBy("first_name").sum("salary")
df_names_agg.explain()

== Physical Plan ==
*(6) HashAggregate(keys=[first_name#343], functions=[sum(salary#349)])
+- *(6) HashAggregate(keys=[first_name#343], functions=[partial_sum(salary#349)])
   +- Exchange hashpartitioning(first_name#343, 10), REPARTITION_BY_NUM, [id=#3099]
      +- *(5) Project [first_name#343, salary#349]
         +- *(5) SortMergeJoin [department_id#350], [department_id#493], Inner
            :- *(2) Sort [department_id#350 ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(department_id#350, 200), ENSURE_REQUIREMENTS, [id=#3083]
            :     +- *(1) Filter isnotnull(department_id#350)
            :        +- FileScan csv [first_name#343,salary#349,department_id#350] Batched: false, DataFilters: [isnotnull(department_id#350)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/home/jupyter/data/employee_records.txt], PartitionFilters: [], PushedFilters: [IsNotNull(department_id)], ReadSchema: struct<first_name:string,salary:int,department_id:int>
      

In [77]:
# Lets try same aggregation on data not partitioned by that column
df_joined_agg = df_joined.groupBy("first_name").sum("salary")
df_joined_agg.explain()
# we can see there is a Exchange hashpartitioning in the second line now

== Physical Plan ==
*(6) HashAggregate(keys=[first_name#343], functions=[sum(salary#349)])
+- Exchange hashpartitioning(first_name#343, 200), ENSURE_REQUIREMENTS, [id=#3200]
   +- *(5) HashAggregate(keys=[first_name#343], functions=[partial_sum(salary#349)])
      +- *(5) Project [first_name#343, salary#349]
         +- *(5) SortMergeJoin [department_id#350], [department_id#493], Inner
            :- *(2) Sort [department_id#350 ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(department_id#350, 200), ENSURE_REQUIREMENTS, [id=#3183]
            :     +- *(1) Filter isnotnull(department_id#350)
            :        +- FileScan csv [first_name#343,salary#349,department_id#350] Batched: false, DataFilters: [isnotnull(department_id#350)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/home/jupyter/data/employee_records.txt], PartitionFilters: [], PushedFilters: [IsNotNull(department_id)], ReadSchema: struct<first_name:string,salary:int,department_id:int>
    