In [1]:
from pyspark.sql import SparkSession

In [2]:
# Create (or reuse, if one is already running in this kernel) the Spark session.
spark = SparkSession.builder.appName("SortMergeJoin").getOrCreate()

23/05/25 20:11:59 WARN Utils: Your hostname, wedivv-H110M-S2V resolves to a loopback address: 127.0.1.1; using 192.168.1.44 instead (on interface wlp5s0)
23/05/25 20:11:59 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/05/25 20:12:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
import random

In [4]:
# Disable automatic broadcast joins (-1 = never broadcast) so the large equi-join
# below has to use a shuffle-based join (SortMergeJoin in the explain() output).
spark.conf.set(key="spark.sql.autoBroadcastJoinThreshold", value=-1)

In [5]:
# Read the setting back and fail fast if broadcast joins are still enabled:
# the rest of this notebook relies on Spark planning a SortMergeJoin.
broadcast_threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
assert broadcast_threshold == "-1", f"expected broadcast joins disabled, got {broadcast_threshold!r}"
broadcast_threshold


'-1'

In [6]:
# Lookup tables for synthetic data, keyed by position 0..5:
# US state codes and item SKUs.
states = dict(enumerate(("AZ", "CO", "CA", "TX", "NY", "MI")))
items = {n: f"SKU-{n}" for n in range(6)}

In [7]:
# Synthetic users: one row per uid in [0, 1_000_000] with a random home state.
# A dedicated seeded RNG keeps the data reproducible across Restart & Run All.
# The state is drawn from a tuple of the lookup values: random.choice expects a
# sequence, and passing the dict itself only worked because its keys happen to be 0..5.
# The loop variable is `uid` rather than `id` to avoid shadowing the builtin.
users_rng = random.Random(42)
user_states = tuple(states.values())
users_data = [
    (uid, f"user_{uid}", f"user_{uid}@databricks.com", users_rng.choice(user_states))
    for uid in range(1000001)
]

# Hash-partition by the join key (shows up as Exchange ... REPARTITION_BY_COL in explain()).
usersDF = spark.createDataFrame(users_data, ["uid", "login", "email", "user_state"]).repartition("uid")

In [8]:
# Synthetic orders: one row per transaction_id r in [0, 1_000_000].
# users_id is drawn uniformly from 0..9999, so only uids 0..9999 can match in the join.
# A dedicated seeded RNG keeps the data reproducible across Restart & Run All.
# State/item are drawn from tuples of the lookup values: random.choice expects a
# sequence, and passing the dicts only worked because their keys happen to be 0..5.
orders_rng = random.Random(7)
order_states = tuple(states.values())
order_items = tuple(items.values())
orders_data = [
    (
        r,                            # transaction_id
        r,                            # quantity
        orders_rng.randint(0, 9999),  # users_id
        10 * r * 0.2,                 # amount
        orders_rng.choice(order_states),
        orders_rng.choice(order_items),
    )
    for r in range(1000001)
]

# Hash-partition by the join key, mirroring usersDF.
ordersDF = spark.createDataFrame(
    orders_data, ["transaction_id", "quantity", "users_id", "amount", "state", "items"]
).repartition("users_id")

In [9]:
# Inner equi-join of orders to users on the user key. With broadcasting disabled,
# this is planned as a SortMergeJoin (see explain() below).
usersOrdersDF = ordersDF.join(
    usersDF,
    on=ordersDF["users_id"] == usersDF["uid"],
    how="inner",
)

In [10]:
# Trigger the join and print the first 20 rows without truncating wide columns.
usersOrdersDF.show(n=20, truncate=False)

23/05/25 20:12:56 WARN TaskSetManager: Stage 0 contains a task of very large size (6839 KiB). The maximum recommended task size is 1000 KiB.
23/05/25 20:12:58 WARN TaskSetManager: Stage 1 contains a task of very large size (12604 KiB). The maximum recommended task size is 1000 KiB.

+--------------+--------+--------+--------+-----+-----+---+------+---------------------+----------+
|transaction_id|quantity|users_id|amount  |state|items|uid|login |email                |user_state|
+--------------+--------+--------+--------+-----+-----+---+------+---------------------+----------+
|5790          |5790    |0       |11580.0 |CO   |SKU-5|0  |user_0|user_0@databricks.com|TX        |
|7482          |7482    |0       |14964.0 |NY   |SKU-0|0  |user_0|user_0@databricks.com|TX        |
|10647         |10647   |0       |21294.0 |AZ   |SKU-2|0  |user_0|user_0@databricks.com|TX        |
|30588         |30588   |0       |61176.0 |TX   |SKU-1|0  |user_0|user_0@databricks.com|TX        |
|31361         |31361   |0       |62722.0 |CA   |SKU-1|0  |user_0|user_0@databricks.com|TX        |
|36367         |36367   |0       |72734.0 |CA   |SKU-2|0  |user_0|user_0@databricks.com|TX        |
|45868         |45868   |0       |91736.0 |CO   |SKU-5|0  |user_0|user_0@databricks.com|TX        |


                                                                                

In [11]:
# Print the physical plan: note the Exchange (shuffle) + Sort on both join inputs.
usersOrdersDF.explain(mode="simple")

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SortMergeJoin [users_id#10L], [uid#0L], Inner
   :- Sort [users_id#10L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(users_id#10L, 200), REPARTITION_BY_COL, [plan_id=143]
   :     +- Filter isnotnull(users_id#10L)
   :        +- Scan ExistingRDD[transaction_id#8L,quantity#9L,users_id#10L,amount#11,state#12,items#13]
   +- Sort [uid#0L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(uid#0L, 200), REPARTITION_BY_COL, [plan_id=146]
         +- Filter isnotnull(uid#0L)
            +- Scan ExistingRDD[uid#0L,login#1,email#2,user_state#3]




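# Optimizing the shuffle sort merge join

Bucket both tables on their join key (`uid` / `users_id`) into the same number of buckets and save them as tables. This lets Spark join them without the `Exchange` (shuffle) step shown in the plan above.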

In [12]:
# NOTE(review): this wildcard import replaces Python builtins such as sum, min, max,
# abs and round with their pyspark.sql.functions counterparts for the rest of the
# notebook. Only `asc` is used in the visible cells, so consider
# `from pyspark.sql.functions import asc` (or `import pyspark.sql.functions as F`).
from pyspark.sql.functions import *

In [13]:
# Start from a clean slate on re-runs.
# Drop BOTH tables that the cells below rewrite (previously only UsersTbl was dropped).
# Dropping a table also releases any cache left by CACHE TABLE in an earlier run.
spark.sql("DROP TABLE IF EXISTS UsersTbl")
spark.sql("DROP TABLE IF EXISTS OrdersTbl")


DataFrame[]

In [14]:
# Write users as a bucketed Parquet table: 8 buckets hashed on the join key `uid`
# (must match OrdersTbl's bucket count so the join can skip the shuffle).
# `.sortBy("uid")` is the writer API for sorting rows inside each bucket file.
# The previous `usersDF.orderBy(asc("uid"))` only added an extra range-partitioning
# shuffle, and its global order is not guaranteed to survive the bucketed write.
(
    usersDF
    .write.format("parquet")
    .bucketBy(8, "uid")
    .sortBy("uid")
    .mode("overwrite")
    .saveAsTable("UsersTbl")
)


23/05/25 20:13:01 WARN TaskSetManager: Stage 5 contains a task of very large size (12604 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

In [15]:
# Write orders as a bucketed Parquet table: 8 buckets hashed on the join key
# `users_id` (same bucket count as UsersTbl so the join can skip the shuffle).
# `.sortBy("users_id")` sorts rows inside each bucket file. It replaces the previous
# global `orderBy`, which added an extra shuffle. The save mode is lower-cased for
# consistency with the UsersTbl write.
(
    ordersDF
    .write.format("parquet")
    .bucketBy(8, "users_id")
    .sortBy("users_id")
    .mode("overwrite")
    .saveAsTable("OrdersTbl")
)

23/05/25 20:13:07 WARN TaskSetManager: Stage 13 contains a task of very large size (6839 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

In [16]:
# Cache both bucketed tables in memory. CACHE TABLE without LAZY runs a job immediately,
# so the join below reads them through an in-memory scan.
# NOTE: the timing comparison further down therefore includes the effect of caching,
# not only bucketing.
spark.sql("CACHE TABLE UsersTbl")
spark.sql("CACHE TABLE OrdersTbl")

                                                                                

DataFrame[]

In [17]:
# Read the bucketed tables back through the catalog so Spark sees their bucket spec.
usersBucketDF, ordersBucketDF = (spark.table(name) for name in ("UsersTbl", "OrdersTbl"))

In [18]:
# Same inner equi-join as before, now between the bucketed tables. Both sides are
# bucketed on the join key, so no Exchange appears in the plan.
joinUsersOrdersBucketDF = ordersBucketDF.join(
    usersBucketDF,
    on=ordersBucketDF.users_id == usersBucketDF.uid,
)

In [19]:
# Trigger the bucketed join and print the first 20 rows without truncation.
joinUsersOrdersBucketDF.show(n=20, truncate=False)

+--------------+--------+--------+--------+-----+-----+---+------+---------------------+----------+
|transaction_id|quantity|users_id|amount  |state|items|uid|login |email                |user_state|
+--------------+--------+--------+--------+-----+-----+---+------+---------------------+----------+
|16668         |16668   |2       |33336.0 |MI   |SKU-3|2  |user_2|user_2@databricks.com|NY        |
|35005         |35005   |2       |70010.0 |CA   |SKU-0|2  |user_2|user_2@databricks.com|NY        |
|37983         |37983   |2       |75966.0 |CO   |SKU-3|2  |user_2|user_2@databricks.com|NY        |
|46158         |46158   |2       |92316.0 |TX   |SKU-3|2  |user_2|user_2@databricks.com|NY        |
|52445         |52445   |2       |104890.0|CO   |SKU-0|2  |user_2|user_2@databricks.com|NY        |
|54882         |54882   |2       |109764.0|AZ   |SKU-0|2  |user_2|user_2@databricks.com|NY        |
|86650         |86650   |2       |173300.0|NY   |SKU-5|2  |user_2|user_2@databricks.com|NY        |


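Observed `show()` time: ~5.6 s for the unbucketed join above vs ~0.6 s for this bucketed join. Both bucketed tables were also cached with `CACHE TABLE`, so the speed-up reflects caching as well as the avoided shuffle.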

In [20]:
# Print the physical plan: the bucketed, cached inputs feed the SortMergeJoin without an Exchange.
joinUsersOrdersBucketDF.explain(mode="simple")

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SortMergeJoin [users_id#230L], [uid#91L], Inner
   :- Sort [users_id#230L ASC NULLS FIRST], false, 0
   :  +- Filter isnotnull(users_id#230L)
   :     +- Scan In-memory table OrdersTbl [transaction_id#228L, quantity#229L, users_id#230L, amount#231, state#232, items#233], [isnotnull(users_id#230L)]
   :           +- InMemoryRelation [transaction_id#228L, quantity#229L, users_id#230L, amount#231, state#232, items#233], StorageLevel(disk, memory, deserialized, 1 replicas)
   :                 +- *(1) ColumnarToRow
   :                    +- FileScan parquet spark_catalog.default.orderstbl[transaction_id#228L,quantity#229L,users_id#230L,amount#231,state#232,items#233] Batched: true, Bucketed: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/home/wedivv/Code/others/zDEJourney/spark/LearningSpark/spark-war..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<transaction_id:bigint,quantity: