## Shuffling ("exchange" in query plan): produce datasets that can be transfered onto separate data nodes, and manipulated as sub-tasks (e.g. sub-joins) to avoid transfering all the data multiple times across joins, etc.
## one dataset would have all the ROWS of one or more value of the join key, making sure that join will have full sub-result of that value

## used in operations such as groupby(), join(), repartition(), distinct()

## smaller datasets (not compared to each other but in spark settings "spark.sql.autoBroadcastJoinThreshold") are broadcasted so that shuffling would not happen!

In [1]:
from pyspark.sql import *

spark = (SparkSession.builder
         #.master("local")
         .appName("Shuffle perf test with unordered vs ordered join keys")
         .enableHiveSupport() # enableHiveSupport() needed to make data persistent...
         .config("spark.log.level", "ERROR")
         .config("spark.executor.memory", "4g")
         .config("spark.driver.memory", "4g")
         .config("spark.sql.autoBroadcastJoinThreshold", -1) # max. byte size of table to be broadcast. -1 = disable
         # enable join reorder (doesn't seem to work...)
         .config("spark.sql.cbo.enabled", True) # default is false
         .config("spark.sql.cbo.joinReorder.enabled", True) # default is false
         .config("spark.sql.cbo.joinReorder.dp.threshold", 20) # max. number of jon reorders. default is 12
         .config("spark.sql.statistics.histogram.enabled", True)
         .getOrCreate())

print('spark version:', spark.version)
print("Done.")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/11 20:05:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting Spark log level to "ERROR".


spark version: 3.5.5
Done.


In [2]:
# generate some test data
from pyspark.sql import functions as F
import random
import math

num_airports =   1000 # max. 26^4
num_flights  = 1000000
num_delays   = 1000
num_alerts   = 1000

# origin_airport
letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
num_airports = min(num_airports, pow(len(letters), 4))
origin_airport_lst = [
    ( 
        
        ''.join([letters[c1 // pow(len(letters), c) % len(letters)] for c in range(4)]),                                # airport (ICAO code)
        (
            ''.join([letters[c1 // pow(len(letters), c) % len(letters)] for c in range(4)]) +                              # display_airport_name
            ' ' +
            ''.join([letters[random.randint(0, len(letters)-1)] for c in range(random.randint(5, 10))])
        ).title(),
     )
    for c1 in range(num_airports)]

origin_airport_df = spark.createDataFrame(origin_airport_lst, ['airport', 'display_airport_name']).select(
    F.col('airport').alias('origin'), 
    F.col('display_airport_name').alias('origin_display_name'))
origin_airport_df.show(5, truncate=False)

# flights
flights_lst = [
    (c1,                                                               # flight_id
     origin_airport_lst[c1%(len(origin_airport_lst))][0],              # origin
     origin_airport_lst[(c1 + 1)%(len(origin_airport_lst))][0]         # dest
     )
    for c1 in range(num_flights)]
flights_df = spark.createDataFrame(flights_lst, ['flight_id', 'origin', 'dest'])
flights_df.show(5, truncate=False)

def non_negative(n):
    return 0 if n < 0 else n

# delays
delays_lst = [
    (c1,                                            # flight_id
     non_negative(random.randint(-40, 35)),        # dep_delay_mins
     )
    for c1 in range(num_delays)]
delays_df = spark.createDataFrame(delays_lst, ['flight_id', 'dep_delay_mins'])
delays_df.show(5, truncate=False)

# flight_alerts
flight_alerts_lst = [
    (c1,                                                   # flight_id
     "some alert!" if random.randint(1,10) <= 3 else ""    # alert
     )
    for c1 in range(num_alerts)]
flight_alerts_df = spark.createDataFrame(flight_alerts_lst, ['flight_id', 'alert'])
flight_alerts_df.show(5, truncate=False)
print("Done.")

                                                                                

+------+-------------------+
|origin|origin_display_name|
+------+-------------------+
|AAAA  |Aaaa Ntldjex       |
|BAAA  |Baaa Jqlvv         |
|CAAA  |Caaa Izlyxr        |
|DAAA  |Daaa Rlvfr         |
|EAAA  |Eaaa Aaymq         |
+------+-------------------+
only showing top 5 rows

+---------+------+----+
|flight_id|origin|dest|
+---------+------+----+
|0        |AAAA  |BAAA|
|1        |BAAA  |CAAA|
|2        |CAAA  |DAAA|
|3        |DAAA  |EAAA|
|4        |EAAA  |FAAA|
+---------+------+----+
only showing top 5 rows

+---------+--------------+
|flight_id|dep_delay_mins|
+---------+--------------+
|0        |0             |
|1        |0             |
|2        |0             |
|3        |6             |
|4        |23            |
+---------+--------------+
only showing top 5 rows

+---------+-----------+
|flight_id|alert      |
+---------+-----------+
|0        |           |
|1        |some alert!|
|2        |           |
|3        |           |
|4        |           |
+---------+--

In [3]:
# dataframe join test

from datetime import datetime

# join keys VARY in execution order => reshuffling is necessary!
ct = datetime.now()
result_df = (
    flights_df
    .select("flight_id", "origin", "dest")
    .join(delays_df, "flight_id", "left") # flight_id shuffle
    .join(origin_airport_df, "origin", "left") # origin shuffle
    .join(flight_alerts_df, "flight_id", "left") # flight_id re-shuffled
)

#result_df = result_df.na.fill(value="", subset = ['alert'])
result_df.write.mode("overwrite").csv("result_df1") # to make sure all results are used, not only a few that's shown
print(f"Unsorted join key done in {(datetime.now() - ct)} secs.")

# join keys DO NOT VARY in execution order => re-shuffling is NOT necessary!
ct = datetime.now()
result_df = (
    flights_df
    .select("flight_id", "origin", "dest")
    .join(origin_airport_df, "origin", "left") # origin shuffle
    .join(delays_df, "flight_id", "left") # flight_id shuffle
    .join(flight_alerts_df, "flight_id", "left") # flight_id, no re-shuffling needed
)

#result_df = result_df.na.fill(value="", subset = ['alert'])
result_df.write.mode("overwrite").csv("result_df2") # to make sure all results are used, not only a few that's shown
print(f"Sorted join key done in {(datetime.now() - ct)} secs.")

                                                                                

Unsorted join key done in 0:00:22.142348 secs.




Sorted join key done in 0:00:15.125590 secs.


                                                                                

In [5]:
# sql join test - join order is kept the same as defined in the query by default and unable to change that...!

from datetime import datetime

# enable join reordering!
spark.conf.set("spark.sql.cbo.enabled", True) # default is false
spark.conf.set("spark.sql.cbo.joinReorder.enabled", True) # default is false
spark.conf.set("spark.sql.cbo.joinReorder.dp.threshold", 20) # max. number of jon reorders. default is 12


spark.sql('USE MYTESTDB')
origin_airport_df.createOrReplaceTempView('origin_airport')
flights_df.createOrReplaceTempView('flights')
delays_df.createOrReplaceTempView('delays')
flight_alerts_df.createOrReplaceTempView('flight_alerts')

# cache table for ANALYZE
spark.catalog.cacheTable("origin_airport")
spark.catalog.cacheTable("flights")
spark.catalog.cacheTable("delays")
spark.catalog.cacheTable("flight_alerts")

# ANALYZE for join reordering
spark.sql("ANALYZE TABLE origin_airport COMPUTE STATISTICS FOR COLUMNS origin")
spark.sql("ANALYZE TABLE flights COMPUTE STATISTICS FOR COLUMNS flight_id")
spark.sql("ANALYZE TABLE delays COMPUTE STATISTICS FOR COLUMNS flight_id")
spark.sql("ANALYZE TABLE flight_alerts COMPUTE STATISTICS FOR COLUMNS flight_id")

# join keys VARY in execution order => reshuffling is necessary!
ct = datetime.now()
result_df = spark.sql('''
SELECT f.flight_id, f.origin, f.dest
FROM flights f
    left outer join delays d
    on (f.flight_id = d.flight_id)
    left outer join origin_airport oa
    on (f.origin = oa.origin)
    left outer join flight_alerts fa
    on (f.flight_id = fa.flight_id)
''')
result_df.write.mode("overwrite").csv("result_df1") # action
print(f"Unsorted join key done in {(datetime.now() - ct)} secs.")

# join keys DO NOT VARY in execution order => re-shuffling is NOT necessary!
ct = datetime.now()
result_df = spark.sql('''
SELECT f.flight_id, f.origin, f.dest
FROM flights f
    left outer join origin_airport oa
    on (f.origin = oa.origin)
    left outer join delays d
    on (f.flight_id = d.flight_id)
    left outer join flight_alerts fa
    on (f.flight_id = fa.flight_id)
''')
result_df.write.mode("overwrite").csv("result_df2") # action
print(f"Sorted join key done in {(datetime.now() - ct)} secs.")

                                                                                

Unsorted join key done in 0:00:05.250045 secs.




Sorted join key done in 0:00:02.723144 secs.


                                                                                

In [4]:
spark.stop()
print('Done.')

Done.
