In [2]:
# Spark Session
from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .appName("Optimizing Joins")
    .master("local[*]")
    .config("spark.cores.max", 16)
    .config("spark.executor.cores", 4)
    .config("spark.executor.memory", "512M")
    .getOrCreate()
)

spark

In [3]:
# Disable AQE and Broadcast join

spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", False)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

In [6]:
# Join big and small tables - SortMerge vs BroadCast Join

# Read EMP CSV data

_schema = "first_name string, last_name string, job_title string, dob string, email string, phone string, salary double, department_id int"
emp = spark.read.format("csv").schema(_schema).option("header", True).load("data/input/employee_records.csv")

# Read DEPT CSV data

_dept_schema = "department_id int, department_name string, description string, city string, state string, country string"
dept = spark.read.format("csv").schema(_dept_schema).option("header", True).load("data/input/department_data.csv")


In [7]:
# Join Datasets using SortMerge Join 

df_joined = emp.join(dept, on=emp.department_id==dept.department_id, how="left_outer")
df_joined.write.format("noop").mode("overwrite").save()

                                                                                

In [8]:
df_joined.explain()

== Physical Plan ==
*(4) SortMergeJoin [department_id#39], [department_id#48], LeftOuter
:- *(1) Sort [department_id#39 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(department_id#39, 200), ENSURE_REQUIREMENTS, [plan_id=70]
:     +- FileScan csv [first_name#32,last_name#33,job_title#34,dob#35,email#36,phone#37,salary#38,department_id#39] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBasics/data/input/employee_records...., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<first_name:string,last_name:string,job_title:string,dob:string,email:string,phone:string,s...
+- *(3) Sort [department_id#48 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(department_id#48, 200), ENSURE_REQUIREMENTS, [plan_id=82]
      +- *(2) Filter isnotnull(department_id#48)
         +- FileScan csv [department_id#48,department_name#49,description#50,city#51,state#52,country#53] Batched: false, DataFilters: [is

In [11]:
# Join Datasets using broadcast
# Preferred only when one table is small and other is big
from pyspark.sql.functions import broadcast

df_joined = emp.join(broadcast(dept), on=emp.department_id==dept.department_id, how="left_outer")
df_joined.write.format("noop").mode("overwrite").save()

                                                                                

In [10]:
df_joined.explain()

== Physical Plan ==
*(2) BroadcastHashJoin [department_id#39], [department_id#48], LeftOuter, BuildRight, false
:- FileScan csv [first_name#32,last_name#33,job_title#34,dob#35,email#36,phone#37,salary#38,department_id#39] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBasics/data/input/employee_records...., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<first_name:string,last_name:string,job_title:string,dob:string,email:string,phone:string,s...
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [plan_id=160]
   +- *(1) Filter isnotnull(department_id#48)
      +- FileScan csv [department_id#48,department_name#49,description#50,city#51,state#52,country#53] Batched: false, DataFilters: [isnotnull(department_id#48)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBasics/data/input/department_data.csv], PartitionFilters: 

In [12]:
# Join Big and Big table - SortMerge without Buckets

In [14]:
# Read Sales data

sales_schema = "transacted_at string, trx_id string, retailer_id string, description string, amount double, city_id string"

sales = spark.read.format("csv").schema(sales_schema).option("header", True).load("data/input/new_sales.csv")

In [16]:
# Read City data

city_schema = "city_id string, city string, state string, state_abv string, country string"

city = spark.read.format("csv").schema(city_schema).option("header", True).load("data/input/cities.csv")

In [17]:
# Join Datasets using SortMerge Join 

df_joined = sales.join(city, on=sales.city_id==city.city_id, how="left_outer")
df_joined.write.format("noop").mode("overwrite").save()

                                                                                

In [18]:
df_joined.explain()

== Physical Plan ==
*(4) SortMergeJoin [city_id#235], [city_id#242], LeftOuter
:- *(1) Sort [city_id#235 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(city_id#235, 200), ENSURE_REQUIREMENTS, [plan_id=279]
:     +- FileScan csv [transacted_at#230,trx_id#231,retailer_id#232,description#233,amount#234,city_id#235] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBasics/data/input/new_sales.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<transacted_at:string,trx_id:string,retailer_id:string,description:string,amount:double,cit...
+- *(3) Sort [city_id#242 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(city_id#242, 200), ENSURE_REQUIREMENTS, [plan_id=291]
      +- *(2) Filter isnotnull(city_id#242)
         +- FileScan csv [city_id#242,city#243,state#244,state_abv#245,country#246] Batched: false, DataFilters: [isnotnull(city_id#242)], Format: CSV, Location: InMemoryFileIndex(1 pat

In [21]:
# Join Datasets using SortMerge Join 
from pyspark.sql.functions import broadcast

df_joined = sales.join(broadcast(city), on=sales.city_id==city.city_id, how="left_outer")
df_joined.write.format("noop").mode("overwrite").save()

# Expected to fail due to Out of Memeory Issue 
# This leads to either spill the data on Disk or increase input size

25/01/13 18:39:51 WARN MemoryStore: Not enough space to cache broadcast_27 in memory! (computed 432.0 MiB so far)
25/01/13 18:39:51 WARN BlockManager: Persisting block broadcast_27 to disk instead.
25/01/13 18:39:53 WARN MemoryStore: Not enough space to cache broadcast_27 in memory! (computed 432.0 MiB so far)
                                                                                

In [20]:
df_joined.explain()

== Physical Plan ==
*(2) BroadcastHashJoin [city_id#235], [city_id#242], LeftOuter, BuildRight, false
:- FileScan csv [transacted_at#230,trx_id#231,retailer_id#232,description#233,amount#234,city_id#235] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBasics/data/input/new_sales.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<transacted_at:string,trx_id:string,retailer_id:string,description:string,amount:double,cit...
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false]),false), [plan_id=369]
   +- *(1) Filter isnotnull(city_id#242)
      +- FileScan csv [city_id#242,city#243,state#244,state_abv#245,country#246] Batched: false, DataFilters: [isnotnull(city_id#242)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBasics/data/input/cities.csv], PartitionFilters: [], PushedFilters: [IsNotNull(city_id)], ReadSchema: struct<city_id:string,cit

In [None]:
# Note: Spark murmur hasing 

In [28]:
# Write Sales and City data in Buckets
# import os
# os.makedirs("/data/input/datasets/", exist_ok=True)

sales.write.format("csv").mode("overwrite").bucketBy(4, "city_id").option("header", True).option("path", "/data/input/datasets/sales_bucket.csv").saveAsTable("sales_bucket")

# Total_bukcet_created = 8 (parttions) * 4 (buckets in each partitions) = 32 buckets


                                                                                

In [30]:
# Write City data in Buckets

city.write.format("csv").mode("overwrite").bucketBy(4, "city_id").option("header", True).option("path", "data/input/datasets/city_bucket.csv").saveAsTable("city_bucket")

                                                                                

In [32]:
# if spark.catalog.tableExists("city_bucket"):
#     print("Table sales_bukcet exists.")
# else:
#     print("Table sales_bukcet does not exist.")

# Check tables

spark.sql("show tables in default").show()

+---------+------------+-----------+
|namespace|   tableName|isTemporary|
+---------+------------+-----------+
|  default| city_bucket|      false|
|  default|sales_bukcet|      false|
+---------+------------+-----------+



In [33]:
sales_bucket = spark.read.table("sales_bukcet")

In [34]:
city_bucket = spark.read.table("city_bucket")

In [35]:
df_joined_bucket = sales_bucket.join(city_bucket, on= sales_bucket.city_id==city_bucket.city_id, how = "left_outer")

In [36]:
df_joined_bucket.write.format("noop").mode("overwrite").save()

                                                                                

In [38]:
df_joined_bucket.explain()
# noShuffle involved here
# 4 tasks were ran parallel as we've 4 buckets here
# For ex. Task 0, reads total 2026174 which includes records from both sales and city tables from 
# 0th buckets 

# Key Note: 
# Data spilled on Disk memory and there is scope for optimsation

== Physical Plan ==
*(3) SortMergeJoin [city_id#490], [city_id#497], LeftOuter
:- *(1) Sort [city_id#490 ASC NULLS FIRST], false, 0
:  +- FileScan csv spark_catalog.default.sales_bukcet[transacted_at#485,trx_id#486,retailer_id#487,description#488,amount#489,city_id#490] Batched: false, Bucketed: true, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBasics/spark-warehouse/data/input/d..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<transacted_at:string,trx_id:string,retailer_id:string,description:string,amount:double,cit..., SelectedBucketsCount: 4 out of 4
+- *(2) Sort [city_id#497 ASC NULLS FIRST], false, 0
   +- *(2) Filter isnotnull(city_id#497)
      +- FileScan csv spark_catalog.default.city_bucket[city_id#497,city#498,state#499,state_abv#500,country#501] Batched: false, Bucketed: true, DataFilters: [isnotnull(city_id#497)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/vrushabh.deokar/pysparkBas

In [1]:
# Join Selection Strategies in PySpark
# SORT_MERGE, SHUFFLE_JOIN and BROADCAST_HASH_JOIN

# If join column is not same as bucketing column and same bucket size - data will shuffle
# If join column same as bucketing column and one table bucket - data will shuffle on smaller table
# If join column same as bucketing column and different bucket size - data will shuffle on smaller table
# If join column same as bucketing column and same bucket size - No shuffle (Fast op)