In [None]:
# Spark Session
# The master is pinned to 2 local cores (instead of "local[*]") so that the
# partition counts recorded below -- defaultParallelism = 2, Range splits=2 --
# are reproducible on any machine rather than depending on its core count.
# Note: getOrCreate() returns an already-running session unchanged, so restart
# the kernel if a session was created earlier with a different master.
from pyspark.sql import SparkSession

SPARK_MASTER = "local[2]"

spark = (
    SparkSession
    .builder
    .appName("Understand Plans and DAG")
    .master(SPARK_MASTER)
    .getOrCreate()
)

spark

In [None]:
# Disable Adaptive Query Execution (AQE) and automatic broadcast joins, so the
# physical plan keeps its static shuffles (hashpartitioning into 200 partitions)
# and the join is planned as a SortMergeJoin.
# A threshold of -1 for autoBroadcastJoinThreshold turns broadcasting off.
for conf_key, conf_value in {
    "spark.sql.adaptive.enabled": False,
    "spark.sql.adaptive.coalescePartitions.enabled": False,
    "spark.sql.autoBroadcastJoinThreshold": -1,
}.items():
    spark.conf.set(conf_key, conf_value)

In [None]:
# Check default parallelism: the number of tasks Spark schedules by default.
# spark.range() below uses it as its number of splits (see splits=2 in the plan).
spark.sparkContext.defaultParallelism

2

In [None]:
# Create dataframes, each with a single `id` column:
#   df_1 -> 4, 6, 8, ..., 198   (step 2)
#   df_2 -> 2, 6, 10, ..., 198  (step 4)
df_1 = spark.range(start=4, end=200, step=2)
df_2 = spark.range(start=2, end=200, step=4)

In [None]:
# Partitions of df_2 as created by spark.range (matches Range splits=2 in the plan)
df_2.rdd.getNumPartitions()

2

In [None]:
# Re-partition data: each repartition(n) adds a round-robin shuffle
# (RoundRobinPartitioning(5) / RoundRobinPartitioning(7) in the plan)
df_3, df_4 = df_1.repartition(5), df_2.repartition(7)

In [None]:
# Partitions of df_4 after repartition(7)
df_4.rdd.getNumPartitions()

7

In [None]:
# Join the dataframes: inner equi-join on `id`. With broadcasting disabled,
# Spark plans this as a SortMergeJoin, hash-shuffling both sides on `id` first.
df_joined = df_3.join(df_4, on="id", how="inner")

In [None]:
# Get the sum of ids: a global aggregate, computed as a partial sum per
# partition followed by a final sum after an Exchange to a single partition
df_sum = (
    df_joined
    .selectExpr("sum(id) as total_sum")
)

In [None]:
# View data (triggers a Spark job; n=20 is show()'s default row limit)
df_sum.show(n=20)

+---------+
|total_sum|
+---------+
|     4998|
+---------+



In [None]:
# Explain plan: print only the physical plan ("simple" is explain()'s default mode)
df_sum.explain(mode="simple")

== Physical Plan ==
*(6) HashAggregate(keys=[], functions=[sum(id#0L)])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=194]
   +- *(5) HashAggregate(keys=[], functions=[partial_sum(id#0L)])
      +- *(5) Project [id#0L]
         +- *(5) SortMergeJoin [id#0L], [id#2L], Inner
            :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(id#0L, 200), ENSURE_REQUIREMENTS, [plan_id=178]
            :     +- Exchange RoundRobinPartitioning(5), REPARTITION_BY_NUM, [plan_id=177]
            :        +- *(1) Range (4, 200, step=2, splits=2)
            +- *(4) Sort [id#2L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#2L, 200), ENSURE_REQUIREMENTS, [plan_id=185]
                  +- Exchange RoundRobinPartitioning(7), REPARTITION_BY_NUM, [plan_id=184]
                     +- *(3) Range (2, 200, step=4, splits=2)




In [None]:
# Union the aggregated result with df_4, so the same repartitioned df_4 feeds
# two branches of one query. Spark reuses that shuffle (ReusedExchange in the
# plan below) instead of recomputing it, which should show up as skipped
# stages in the Spark UI.
# NOTE: union() matches columns by position, not by name, so df_4's `id`
# values appear under df_sum's `total_sum` column.
df_union = df_sum.union(df_4)

In [None]:
# Show the union; only the first 20 rows (show()'s default) are printed
df_union.show(n=20)

+---------+
|total_sum|
+---------+
|     4998|
|       14|
|       86|
|       42|
|      146|
|      134|
|      142|
|      162|
|       74|
|       94|
|       34|
|      198|
|      182|
|      126|
|      174|
|       98|
|       10|
|       82|
|      122|
|      186|
+---------+
only showing top 20 rows



In [None]:
# Explain plan: physical plan only ("simple" mode), showing the ReusedExchange
df_union.explain(mode="simple")

== Physical Plan ==
Union
:- *(6) HashAggregate(keys=[], functions=[sum(id#0L)])
:  +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=432]
:     +- *(5) HashAggregate(keys=[], functions=[partial_sum(id#0L)])
:        +- *(5) Project [id#0L]
:           +- *(5) SortMergeJoin [id#0L], [id#2L], Inner
:              :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
:              :  +- Exchange hashpartitioning(id#0L, 200), ENSURE_REQUIREMENTS, [plan_id=416]
:              :     +- Exchange RoundRobinPartitioning(5), REPARTITION_BY_NUM, [plan_id=415]
:              :        +- *(1) Range (4, 200, step=2, splits=2)
:              +- *(4) Sort [id#2L ASC NULLS FIRST], false, 0
:                 +- Exchange hashpartitioning(id#2L, 200), ENSURE_REQUIREMENTS, [plan_id=423]
:                    +- Exchange RoundRobinPartitioning(7), REPARTITION_BY_NUM, [plan_id=422]
:                       +- *(3) Range (2, 200, step=4, splits=2)
+- ReusedExchange [id#24L], Exchange RoundRobinPartitioning(

In [None]:
# DataFrame to RDD: `.rdd` exposes the DataFrame's rows as a Python-side RDD
# (its repr shows it was produced by a javaToPython conversion)

df_1.rdd

MapPartitionsRDD[20] at javaToPython at NativeMethodAccessorImpl.java:0