# 02 Optimize Joins

In [0]:
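# Import only what this notebook uses; a wildcard import shadows Python built-ins such as sum and max
from pyspark.sql.functions import broadcast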

In [0]:
spark.conf.set("spark.sql.adaptive.enabled", True)

In [0]:
spark.conf.get("spark.sql.adaptive.enabled")

Out[8]: 'true'

In [0]:
# Big DataFrame
df_transactions = spark.createDataFrame([
    (1, "US", 100),
    (2, "IN", 200),
    (3, "UK", 150),
    (4, "US", 80),
], ["id", "country_code", "amount"])

# Small DataFrame
df_countries = spark.createDataFrame([
    ("US", "United States"),
    ("IN", "India"),
    ("UK", "United Kingdom"),
], ["country_code", "country_name"])

In [0]:
df_transactions.display()

id,country_code,amount
1,US,100
2,IN,200
3,UK,150
4,US,80


In [0]:
df_countries.display()

country_code,country_name
US,United States
IN,India
UK,United Kingdom


Let's start with a regular join:

In [0]:
df_join = df_transactions.join(
    df_countries,
    df_transactions["country_code"] == df_countries["country_code"],
    how="inner",
)

In [0]:
df_join.display()

id,country_code,amount,country_code.1,country_name
2,IN,200,IN,India
3,UK,150,UK,United Kingdom
1,US,100,US,United States
4,US,80,US,United States


If you inspect the DAG for this query, you will notice:

* A `SortMergeJoin` step, which works in two phases:

  * Sorting:

    Before the join itself, both datasets are sorted by the join key. The merge phase requires the rows to be in key order.

  * Merging:

    After sorting, the two datasets are merged by iterating through them in key order, so matching rows are found in a single pass over each side.

* A `Filter` step on each DataFrame that removes rows whose join key is null, since a null key can never match in an inner join

* `Exchange`: the step where the shuffle happens. This is the crucial step, because both sides are repartitioned by the join key into the number of partitions set by `spark.sql.shuffle.partitions`, which means **200 shuffle partitions by default for joins**, even though the number of records here is tiny
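To make the filter, sort, and merge phases concrete, here is a minimal pure-Python sketch of the sort-merge idea, run on the same rows as the DataFrames above. It is only an illustration under simplifying assumptions (everything fits in one list, no partitioning or shuffle), not Spark's actual implementation; the function name `sort_merge_join` and its parameters are made up for this example.

```python
def sort_merge_join(left, right, left_key, right_key):
    """Inner-join two lists of tuples on the given key positions (illustrative only)."""
    # Filter: drop rows with a null key, since they can never match in an inner join.
    left = [row for row in left if row[left_key] is not None]
    right = [row for row in right if row[right_key] is not None]

    # Sort: order both sides by the join key.
    left = sorted(left, key=lambda row: row[left_key])
    right = sorted(right, key=lambda row: row[right_key])

    # Merge: walk both sorted sides with two cursors.
    joined, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        lk, rk = left[i][left_key], right[j][right_key]
        if lk < rk:
            i += 1
        elif lk > rk:
            j += 1
        else:
            # Find the run of right rows that share this key.
            j_end = j
            while j_end < len(right) and right[j_end][right_key] == lk:
                j_end += 1
            # Pair every left row holding this key with that run.
            while i < len(left) and left[i][left_key] == lk:
                for right_row in right[j:j_end]:
                    joined.append(left[i] + right_row)
                i += 1
            j = j_end
    return joined


transactions = [(1, "US", 100), (2, "IN", 200), (3, "UK", 150), (4, "US", 80)]
countries = [("US", "United States"), ("IN", "India"), ("UK", "United Kingdom")]

result = sort_merge_join(transactions, countries, left_key=1, right_key=0)
for row in result:
    print(row)
```

Because both sides are sorted by `country_code` before merging, the joined rows come out in key order (IN, UK, US).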

## Optimize Joins By Using Broadcast Function

In [0]:
df_join_opt = df_transactions.join(
    broadcast(df_countries),  # hint: ship the small table to every executor
    df_transactions["country_code"] == df_countries["country_code"],
    "inner",
)

In [0]:
df_join_opt.display()

id,country_code,amount,country_code.1,country_name
1,US,100,US,United States
2,IN,200,IN,India
3,UK,150,UK,United Kingdom
4,US,80,US,United States


If you look at the DAG, you'll notice that the large table is no longer shuffled: the smaller table was broadcast to all the worker nodes, so each one can join its own partitions of the large table locally.
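The idea behind a broadcast join can be sketched in plain Python: build a lookup table from the small side once, then let each partition of the large side probe it locally. This is a simplified illustration, not Spark's implementation; the function name `broadcast_hash_join` and the two-partition layout of the transactions are assumptions made for the example.

```python
def broadcast_hash_join(large_partitions, small_rows, large_key, small_key):
    """Inner-join each partition of the large side against a shared lookup table."""
    # Build: hash the small side once. This lookup plays the role of the
    # broadcast copy that every worker would receive.
    lookup = {}
    for row in small_rows:
        if row[small_key] is not None:
            lookup.setdefault(row[small_key], []).append(row)

    # Probe: each partition joins on its own, so the rows of the large
    # side never move between partitions.
    joined_partitions = []
    for partition in large_partitions:
        joined = [
            row + match
            for row in partition
            for match in lookup.get(row[large_key], [])
        ]
        joined_partitions.append(joined)
    return joined_partitions


# Hypothetical layout: the transactions are already split into two partitions.
partitions = [
    [(1, "US", 100), (2, "IN", 200)],
    [(3, "UK", 150), (4, "US", 80)],
]
countries = [("US", "United States"), ("IN", "India"), ("UK", "United Kingdom")]

result = broadcast_hash_join(partitions, countries, large_key=1, small_key=0)
for partition in result:
    print(partition)
```

Each output partition holds exactly the rows of its input partition, enriched with the country name, which is why the joined rows keep the original order of the transactions.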

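**IMPORTANT** Broadcast should be used only on small tables. As a rule of thumb, keep the broadcast table within roughly 5 to 10 MB (10 MB is also the default of `spark.sql.autoBroadcastJoinThreshold`, the size below which Spark broadcasts a table on its own). The full table is copied to the driver and to every executor, so broadcasting a larger table can cause memory pressure and performance problems.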