# Exploring Joins in Spark

We will explore these topics in this notebook:
1. Natural vs Regular Join Expressions
2. Filter Pushdown for Joins
3. Joining on Skewed Data
4. Range Join Conditions

TODO: Find examples where you would bucket your joins

### Library Imports

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

In [2]:
# Create a SparkSession. No need to create SparkContext
# You automatically get it as part of the SparkSession
spark = SparkSession.builder \
    .master("local") \
    .appName("Exploring Joins") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

# Case Study 1: Natural vs Regular Joins

Two types of joins:
1. `Natural Join`  
    A Natural Join is where 2 tables are joined on columns that share the same name. In PySpark the shared column is passed by name and appears only once in the result.  
    e.g. `left.join(right, 'key')`


2. `Regular Join`  
    A Regular Join is where 2 tables are joined on an explicit condition, like the ON clause in SQL. The key columns from both sides are kept in the result.  
    e.g. `left.join(right, left[lkey] == right[rkey])`

**Question:**
    Is `rename`ing a `column` and then doing a `natural join` better than doing a `regular join` on the differently named keys?

### Initial Datasets

In [3]:
df_1 = spark.createDataFrame(
    [
        (1, 1, 'a'), 
        (2, 1, 'b'), 
        (2, 2, 'c'), 
    ], ['id', 'data_id', 'val_1']
)

df_1.toPandas()

Unnamed: 0,id,data_id,val_1
0,1,1,a
1,2,1,b
2,2,2,c


In [5]:
df_2 = spark.createDataFrame(
    [
        (1, 1, 10), 
        (2, 2, 20), 
    ], ['shop_id', 'data_id', 'val_2']
)

df_2.toPandas()

Unnamed: 0,shop_id,data_id,val_2
0,1,1,10
1,2,2,20


## Option 1: Rename Key, then Join

In [6]:
df_3 = df_1.withColumnRenamed('id', 'shop_id')

df = df_3.join(df_2, 'shop_id')

df.toPandas()

Unnamed: 0,shop_id,data_id,val_1,data_id.1,val_2
0,1,1,a,1,10
1,2,1,b,2,20
2,2,2,c,2,20


In [7]:
df.explain()

== Physical Plan ==
*(5) Project [shop_id#12L, data_id#1L, val_1#2, data_id#7L, val_2#8L]
+- *(5) SortMergeJoin [shop_id#12L], [shop_id#6L], Inner
   :- *(2) Sort [shop_id#12L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(shop_id#12L, 200)
   :     +- *(1) Project [id#0L AS shop_id#12L, data_id#1L, val_1#2]
   :        +- *(1) Filter isnotnull(id#0L)
   :           +- Scan ExistingRDD[id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [shop_id#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#6L, 200)
         +- *(3) Filter isnotnull(shop_id#6L)
            +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]


## Option 2: Don't Rename, Regular Join, Drop Column

In [8]:
join_condition = df_1['id'] == df_2['shop_id']

df = df_1.join(df_2, join_condition).drop(df_1['id'])

df.toPandas()

Unnamed: 0,data_id,val_1,shop_id,data_id.1,val_2
0,1,a,1,1,10
1,1,b,2,2,20
2,2,c,2,2,20


In [9]:
df.explain()

== Physical Plan ==
*(5) Project [data_id#1L, val_1#2, shop_id#6L, data_id#7L, val_2#8L]
+- *(5) SortMergeJoin [id#0L], [shop_id#6L], Inner
   :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, 200)
   :     +- *(1) Filter isnotnull(id#0L)
   :        +- Scan ExistingRDD[id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [shop_id#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#6L, 200)
         +- *(3) Filter isnotnull(shop_id#6L)
            +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]


## TL;DR

Option #1
* Looks nicer and more elegant.
* Does it perform better though?

Option #2
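* There is one fewer `Project` step, as expected without the `withColumnRenamed`.
* Both plans have the same two `Exchange` (shuffle) steps and the same `SortMergeJoin`.

Is this better? The extra `Project` in Option #1 runs in the same stage `*(1)` as the `Filter`, before the shuffle, so the difference is likely small.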
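One way to compare the two plans without reading them line by line is to count the operators in each. The sketch below does this on the plan text copied from the two `explain()` outputs above. The helper `count_operators` and its regular expression are hypothetical names written for this illustration, not part of Spark.

```python
import re
from collections import Counter

# Physical plans copied from the explain() outputs of Option 1 and Option 2.
PLAN_RENAME_THEN_JOIN = """
*(5) Project [shop_id#12L, data_id#1L, val_1#2, data_id#7L, val_2#8L]
+- *(5) SortMergeJoin [shop_id#12L], [shop_id#6L], Inner
   :- *(2) Sort [shop_id#12L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(shop_id#12L, 200)
   :     +- *(1) Project [id#0L AS shop_id#12L, data_id#1L, val_1#2]
   :        +- *(1) Filter isnotnull(id#0L)
   :           +- Scan ExistingRDD[id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [shop_id#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#6L, 200)
         +- *(3) Filter isnotnull(shop_id#6L)
            +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]
"""

PLAN_REGULAR_JOIN = """
*(5) Project [data_id#1L, val_1#2, shop_id#6L, data_id#7L, val_2#8L]
+- *(5) SortMergeJoin [id#0L], [shop_id#6L], Inner
   :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, 200)
   :     +- *(1) Filter isnotnull(id#0L)
   :        +- Scan ExistingRDD[id#0L,data_id#1L,val_1#2]
   +- *(4) Sort [shop_id#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#6L, 200)
         +- *(3) Filter isnotnull(shop_id#6L)
            +- Scan ExistingRDD[shop_id#6L,data_id#7L,val_2#8L]
"""

# Skip tree-drawing characters and the optional "*(n)" codegen marker,
# then capture the operator name (hypothetical pattern for this sketch).
OPERATOR = re.compile(r"^[\s:+\-]*(?:\*\(\d+\)\s+)?(\w+)")


def count_operators(plan_text):
    """Count how often each operator name appears in a plan string."""
    names = []
    for line in plan_text.strip().splitlines():
        match = OPERATOR.match(line)
        if match:
            names.append(match.group(1))
    return Counter(names)


option_1 = count_operators(PLAN_RENAME_THEN_JOIN)
option_2 = count_operators(PLAN_REGULAR_JOIN)

print(option_1["Exchange"], option_2["Exchange"])  # 2 2
print(option_1 - option_2)                         # Counter({'Project': 1})
```

The counts only show that the plans differ by one `Project`; they say nothing about runtime, which would have to be measured on realistic data.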

# Case Study 2: Filter Pushdown

`Filter pushdown` improves performance by applying filters as early as possible, which reduces the amount of data shuffled during DataFrame transformations.
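The idea can be illustrated outside Spark with a toy model. In the sketch below, rows are plain tuples and `rows_in` stands in for the data that would be shuffled to the join. The function name and the row layout are assumptions made for this example; the values mirror the small datasets used in this case study.

```python
# Toy model of filter placement around a join (not Spark code).
left = [(1, 1, "a"), (2, 1, "b"), (2, 2, "c")]  # (shop_id, data_id, val_1)
right = [(1, 10), (2, 20)]                      # (data_id, val_2)


def inner_join_on_data_id(left_rows, right_rows):
    """Return (joined rows, number of input rows handed to the join)."""
    rows_in = len(left_rows) + len(right_rows)
    joined = [
        (shop_id, data_id, val_1, val_2)
        for shop_id, data_id, val_1 in left_rows
        for right_data_id, val_2 in right_rows
        if data_id == right_data_id
    ]
    return joined, rows_in


# Filter after the join: every row reaches the join.
joined, rows_in_late = inner_join_on_data_id(left, right)
late = [row for row in joined if row[0] == 1]

# Filter before the join: only the matching left rows reach the join.
filtered_left = [row for row in left if row[0] == 1]
early, rows_in_early = inner_join_on_data_id(filtered_left, right)

print(late == early)                # True
print(rows_in_late, rows_in_early)  # 5 3
```

Both orders give the same result, but the early filter hands fewer rows to the join. This is the rewrite that an optimizer performs when it pushes a filter below a join.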

### Initial Datasets

In [11]:
df_1 = spark.createDataFrame(
    [
        (1, 1, 'a'), 
        (2, 1, 'b'), 
        (2, 2, 'c'), 
    ], ['shop_id', 'data_id', 'val_1']
)

df_1.toPandas()

Unnamed: 0,shop_id,data_id,val_1
0,1,1,a
1,2,1,b
2,2,2,c


In [12]:
df_2 = spark.createDataFrame(
    [
        (1, 1, 10), 
        (2, 2, 20), 
    ], ['shop_id', 'data_id', 'val_2']
)

df_2.toPandas()

Unnamed: 0,shop_id,data_id,val_2
0,1,1,10
1,2,2,20


## Option #1: Join data, then Filter

In [13]:
df_3 = df_1.join(df_2.drop('shop_id'), 'data_id').filter(F.col('shop_id') == 1)

df_3.toPandas()

Unnamed: 0,data_id,shop_id,val_1,val_2
0,1,1,a,10


In [14]:
df_3.explain()

== Physical Plan ==
*(5) Project [data_id#45L, shop_id#44L, val_1#46, val_2#52L]
+- *(5) SortMergeJoin [data_id#45L], [data_id#51L], Inner
   :- *(2) Sort [data_id#45L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(data_id#45L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#44L) && (shop_id#44L = 1)) && isnotnull(data_id#45L))
   :        +- Scan ExistingRDD[shop_id#44L,data_id#45L,val_1#46]
   +- *(4) Sort [data_id#51L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(data_id#51L, 200)
         +- *(3) Project [data_id#51L, val_2#52L]
            +- *(3) Filter isnotnull(data_id#51L)
               +- Scan ExistingRDD[shop_id#50L,data_id#51L,val_2#52L]


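**What Happened:**

* The filter on `shop_id` was pushed down, but only on the left side (`df_1`), where it runs before the shuffle.
* The right side (`df_2`) has no `shop_id` after the `drop`, so it is only filtered on non-null `data_id`.
* This means all of the `df_2` data is brought to the join.

**Results:**

We bring more data to the join and shuffle than needed, **this is bad**.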

## Option #2: Join on Filter Key, Filter After

In [15]:
df_4 = df_1.join(df_2, ['shop_id', 'data_id']).filter(F.col('shop_id') == 1)

df_4.toPandas()

Unnamed: 0,shop_id,data_id,val_1,val_2
0,1,1,a,10


In [16]:
df_4.explain()

== Physical Plan ==
*(5) Project [shop_id#44L, data_id#45L, val_1#46, val_2#52L]
+- *(5) SortMergeJoin [shop_id#44L, data_id#45L], [shop_id#50L, data_id#51L], Inner
   :- *(2) Sort [shop_id#44L ASC NULLS FIRST, data_id#45L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(shop_id#44L, data_id#45L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#44L) && (shop_id#44L = 1)) && isnotnull(data_id#45L))
   :        +- Scan ExistingRDD[shop_id#44L,data_id#45L,val_1#46]
   +- *(4) Sort [shop_id#50L ASC NULLS FIRST, data_id#51L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(shop_id#50L, data_id#51L, 200)
         +- *(3) Filter (((shop_id#50L = 1) && isnotnull(data_id#51L)) && isnotnull(shop_id#50L))
            +- Scan ExistingRDD[shop_id#50L,data_id#51L,val_2#52L]


**What Happened:**
* The filter got pushed down on both sides, because `shop_id` is now a join key.
* Less data is brought to the join and shuffle.

**Results:**

We bring less data to the join and shuffle, **this is good**.

## Option #3: Filter Left, then Join

In [17]:
df_5 = df_1.filter(F.col('shop_id') == 1).join(df_2.drop('shop_id'), 'data_id')

df_5.toPandas()

Unnamed: 0,data_id,shop_id,val_1,val_2
0,1,1,a,10


In [18]:
df_5.explain()

== Physical Plan ==
*(5) Project [data_id#45L, shop_id#44L, val_1#46, val_2#52L]
+- *(5) SortMergeJoin [data_id#45L], [data_id#51L], Inner
   :- *(2) Sort [data_id#45L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(data_id#45L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#44L) && (shop_id#44L = 1)) && isnotnull(data_id#45L))
   :        +- Scan ExistingRDD[shop_id#44L,data_id#45L,val_1#46]
   +- *(4) Sort [data_id#51L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(data_id#51L, 200)
         +- *(3) Project [data_id#51L, val_2#52L]
            +- *(3) Filter isnotnull(data_id#51L)
               +- Scan ExistingRDD[shop_id#50L,data_id#51L,val_2#52L]


**What Happened:**
* The plan is exactly the same as Option #1: only the left side is filtered on `shop_id`.

**Results:**

We bring more data to the join and shuffle than needed, **this is bad**.

## Option #4: Filter Both, then Join

In [19]:
df_6 = df_1.filter(F.col('shop_id') == 1).join(
    df_2.filter(F.col('shop_id') == 1).drop('shop_id'), 
    'data_id'
)

df_6.toPandas()

Unnamed: 0,data_id,shop_id,val_1,val_2
0,1,1,a,10


In [20]:
df_6.explain()

== Physical Plan ==
*(5) Project [data_id#45L, shop_id#44L, val_1#46, val_2#52L]
+- *(5) SortMergeJoin [data_id#45L], [data_id#51L], Inner
   :- *(2) Sort [data_id#45L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(data_id#45L, 200)
   :     +- *(1) Filter ((isnotnull(shop_id#44L) && (shop_id#44L = 1)) && isnotnull(data_id#45L))
   :        +- Scan ExistingRDD[shop_id#44L,data_id#45L,val_1#46]
   +- *(4) Sort [data_id#51L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(data_id#51L, 200)
         +- *(3) Project [data_id#51L, val_2#52L]
            +- *(3) Filter ((isnotnull(shop_id#50L) && (shop_id#50L = 1)) && isnotnull(data_id#51L))
               +- Scan ExistingRDD[shop_id#50L,data_id#51L,val_2#52L]


## TL;DR

* We should always try to push the filter down as much as possible. 
* This means that there will be less data being shuffled and joined during the join. 
* This can be achieved with the join in Option #2 or Option #4.
* Both options also restrict `df_2` by `shop_id`, so they return the same rows as Option #1 only when the matching `df_2` rows also have `shop_id == 1`, as they do in this dataset.

**Option #2** (Good)
* When we `join`ed on the `filter` key `shop_id` as well, the `filter` was pushed down to both sides, which is good.
* But this made us `sort` on 2 keys.

**Option #4** (Better)
* When we pre-`filter` both `join`ing datasets, the `filter` runs on both sides before the shuffle, which is good.
* We only `join` on one key as well, which is good as we only sort on 1 key.

# Case Study 3: Joins on Skewed Data

A `skewed dataset` is a dataset with a class imbalance: a few values account for most of the rows. This leads to slow Spark jobs, or to jobs that never complete because they hit `OOM` (out of memory) errors.

When performing a `join` on a `skewed dataset`, the class imbalance is in the `key`s on which the join is performed. As a result, the majority of the data falls onto one partition, which takes longer to complete than the other partitions.

Some examples of this are:
1. The keys consist mainly of `null` values, which all fall onto a single partition.
2. A small subset of keys makes up the majority of the rows, and all rows for each of those keys fall onto a single partition.
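The first example can be made concrete with a toy partitioner. The sketch below is not Spark's hash function. It only assumes the property that matters here: equal keys, including `null`, always map to the same partition. `NUM_PARTITIONS` and `toy_partition` are names invented for this illustration.

```python
from collections import Counter

NUM_PARTITIONS = 4  # assumed small value; the plans in this notebook use 200


def toy_partition(key, num_partitions=NUM_PARTITIONS):
    """Toy stand-in for hash partitioning: equal keys share a partition."""
    if key is None:
        return 0  # assumption: every null key lands on one fixed partition
    return key % num_partitions


# 1,000 rows with a null key plus 8 rows with distinct integer keys.
keys = [None] * 1000 + list(range(1, 9))
sizes = Counter(toy_partition(k) for k in keys)

print(sorted(sizes.items()))  # [(0, 1002), (1, 2), (2, 2), (3, 2)]
```

One partition holds almost all of the rows, so the task that processes it runs far longer than the others and needs far more memory.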

## Situation 1: Null Keys

Initial Datasets

In [21]:
customers = spark.createDataFrame([
    (1, None), 
    (2, None), 
    (3, 1),
], ["id", "card_id"])

customers.toPandas()

Unnamed: 0,id,card_id
0,1,
1,2,
2,3,1.0


In [22]:
cards = spark.createDataFrame([
    (1, "john", "doe", 21), 
    (2, "rick", "roll", 10), 
    (3, "bob", "brown", 2)
], ["card_id", "first_name", "last_name", "age"])

cards.toPandas()

Unnamed: 0,card_id,first_name,last_name,age
0,1,john,doe,21
1,2,rick,roll,10
2,3,bob,brown,2


### Option #1: Join Regularly

In [23]:
df = customers.join(cards, "card_id", "left")

df.toPandas()

Unnamed: 0,card_id,id,first_name,last_name,age
0,,1,,,
1,,2,,,
2,1.0,3,john,doe,21.0


In [24]:
df.explain()

== Physical Plan ==
*(3) Project [card_id#84L, id#83L, first_name#88, last_name#89, age#90L]
+- SortMergeJoin [card_id#84L], [card_id#87L], LeftOuter
   :- *(1) Sort [card_id#84L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(card_id#84L, 200)
   :     +- Scan ExistingRDD[id#83L,card_id#84L]
   +- *(2) Sort [card_id#87L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(card_id#87L, 200)
         +- Scan ExistingRDD[card_id#87L,first_name#88,last_name#89,age#90L]


**What Happened**:
* Rows with a `null` key, which can never match, were still brought to the join.
* They get `Null` values for the right side columns.

**Results**:
* We brought more data to the join than we had to.

### Option #2: Filter Null Keys First, then Join, then Union

In [27]:
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Filter out the null rows.
        2. Create the columns you would get from the join.
        3. Join the tables.
        4. Union the null rows to joined table.
    """
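    # Rows with a null key can never match, so pad them with typed nulls
    # instead of sending them through the join.
    df1 = left.where(F.col(key).isNull())
    for f in right.schema.fields:
        df1 = df1.withColumn(f.name, F.lit(None).cast(f.dataType))

    # Only rows with a usable key take part in the join.
    df2 = left.where(F.col(key).isNotNull())
    df2 = df2.join(right, key, "left")

    # union() matches columns by position, so align df2 to df1's column order.
    return df1.union(df2.select(df1.columns))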
    
    
df = null_skew_helper(customers, cards, "card_id")

df.toPandas()

Unnamed: 0,id,card_id,first_name,last_name,age
0,1,,,,
1,2,,,,
2,3,1.0,john,doe,21.0


In [28]:
df.explain()

== Physical Plan ==
Union
:- *(1) Project [id#83L, null AS card_id#101L, null AS first_name#104, null AS last_name#108, null AS age#113L]
:  +- *(1) Filter isnull(card_id#84L)
:     +- Scan ExistingRDD[id#83L,card_id#84L]
+- *(5) Project [id#83L, card_id#84L, first_name#88, last_name#89, age#90L]
   +- SortMergeJoin [card_id#84L], [card_id#87L], LeftOuter
      :- *(3) Sort [card_id#84L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#84L, 200)
      :     +- *(2) Filter isnotnull(card_id#84L)
      :        +- Scan ExistingRDD[id#83L,card_id#84L]
      +- *(4) Sort [card_id#87L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(card_id#87L, 200)
            +- Scan ExistingRDD[card_id#87L,first_name#88,last_name#89,age#90L]


**What Happened**:
* We filtered the null-key rows out before the join.
* We did the join with less data.
* We read the table again and got the null-key rows.
* We unioned them with the joined results.

**Results**:
* We brought less data to the join.
* We read the data twice.

### Option #3: Cache Table, Filter Null Keys First, then Join, then Union

In [29]:
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Filter out the null rows.
        2. Create the columns you would get from the join.
        3. Join the tables.
        4. Union the null rows to joined table.
    """
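    # Cache the left table so the two filters below reuse it
    # instead of reading the source twice.
    left = left.cache()

    # Rows with a null key can never match, so pad them with typed nulls.
    df1 = left.where(F.col(key).isNull())
    for f in right.schema.fields:
        df1 = df1.withColumn(f.name, F.lit(None).cast(f.dataType))

    # Only rows with a usable key take part in the join.
    df2 = left.where(F.col(key).isNotNull())
    df2 = df2.join(right, key, "left")

    # union() matches columns by position, so align df2 to df1's column order.
    return df1.union(df2.select(df1.columns))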
    
    
df = null_skew_helper(customers, cards, "card_id")

df.toPandas()

Unnamed: 0,id,card_id,first_name,last_name,age
0,1,,,,
1,2,,,,
2,3,1.0,john,doe,21.0


In [30]:
df.explain()

== Physical Plan ==
Union
:- *(1) Project [id#83L, null AS card_id#146L, null AS first_name#149, null AS last_name#153, null AS age#158L]
:  +- *(1) Filter isnull(card_id#84L)
:     +- *(1) InMemoryTableScan [card_id#84L, id#83L], [isnull(card_id#84L)]
:           +- InMemoryRelation [id#83L, card_id#84L], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
:                 +- Scan ExistingRDD[id#83L,card_id#84L]
+- *(5) Project [id#83L, card_id#84L, first_name#88, last_name#89, age#90L]
   +- SortMergeJoin [card_id#84L], [card_id#87L], LeftOuter
      :- *(3) Sort [card_id#84L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#84L, 200)
      :     +- *(2) Filter isnotnull(card_id#84L)
      :        +- *(2) InMemoryTableScan [id#83L, card_id#84L], [isnotnull(card_id#84L)]
      :              +- InMemoryRelation [id#83L, card_id#84L], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
      :                    +- Scan ExistingRDD[i

**What Happened**:
* Similar to Option #2, but we did an `InMemoryTableScan` instead of a second read of the data.

**Results**:
* We brought less data to the join.
* We did 1 less read.
* But we used more memory.

## TL;DR

As always, there are pros and cons.

Pros:
* Ideally you want to bring less data to a join. 
* Null-key rows can never match, so they are unneeded in the join, and at scale they can cause a Spark job to fail. 
* This is due to the fact that all the null key rows will go onto one partition.

Cons:
* There will either be one extra read of data or more memory used.

All to say:
1. It's generally better to bring less data to a join, so filter out the null keys before the join.
2. This will result in an extra read of data or extra memory usage for the cache.
3. Decide whether you can afford the extra read or the extra memory usage.
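The filter, join, and union pattern can be checked on plain Python rows. The sketch below is a model of the approach rather than Spark code: the right table is a dictionary, and `left_join` and `null_skew_left_join` are hypothetical helpers written for this example. The data mirrors the `customers` and `cards` tables above.

```python
# Plain-Python model of the filter / join / union pattern (not Spark code).
customers = [(1, None), (2, None), (3, 1)]  # (id, card_id)
cards = {                                   # card_id -> (first_name, last_name, age)
    1: ("john", "doe", 21),
    2: ("rick", "roll", 10),
    3: ("bob", "brown", 2),
}
NO_MATCH = (None, None, None)


def left_join(rows, lookup):
    """Left join every row, including rows whose key is null."""
    return [(cid, card) + lookup.get(card, NO_MATCH) for cid, card in rows]


def null_skew_left_join(rows, lookup):
    """Pad null-key rows directly, join only the rest, then union both parts."""
    null_rows = [(cid, card) + NO_MATCH for cid, card in rows if card is None]
    keyed_rows = [(cid, card) for cid, card in rows if card is not None]
    return null_rows + left_join(keyed_rows, lookup), len(keyed_rows)


plain = left_join(customers, cards)
split, rows_joined = null_skew_left_join(customers, cards)

# Same rows either way, but only the keyed rows went through the join.
print(sorted(plain, key=lambda r: r[0]) == sorted(split, key=lambda r: r[0]))  # True
print(len(customers), rows_joined)  # 3 1
```

The result is identical to a plain left join, while only one of the three rows passes through the join step. The cost is the second pass over `customers`, which matches the trade-off listed above.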

# Case Study 4: Range Join Conditions

> A naive approach (just specifying this as the range condition) would result in a full cartesian product and a filter that enforces the condition (tested using Spark 2.0). This has a horrible effect on performance, especially if DataFrames are more than a few hundred thousands records.

source: http://zachmoshe.com/2016/09/26/efficient-range-joins-with-spark.html

> The source of the problem is pretty simple. When you execute join and join condition is not equality based the only thing that Spark can do right now is expand it to Cartesian product followed by filter what is pretty much what happens inside `BroadcastNestedLoopJoin`

source: https://stackoverflow.com/questions/37953830/spark-sql-performance-join-on-value-between-min-and-max?answertab=active#tab-top
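The "Cartesian product followed by filter" behaviour described in the quotes can be modelled with two nested loops. The sketch below is a plain-Python illustration, not Spark's implementation. It uses the same values as the small tables in this case study and counts how many pairs are compared.

```python
# Toy model of a range join evaluated as a nested loop (not Spark code).
geo_ranges = [(1, 10, "foo"), (11, 36, "bar"), (37, 59, "baz")]  # (ipstart, ipend, loc)
records = [(1, 11), (2, 38), (3, 50)]                            # (id, inet)

comparisons = 0
matches = []
for rec_id, inet in records:                # every record ...
    for ipstart, ipend, loc in geo_ranges:  # ... is paired with every range
        comparisons += 1
        if ipstart <= inet <= ipend:        # the range condition acts as a filter
            matches.append((rec_id, inet, loc))

print(comparisons)  # 9
print(matches)      # [(1, 11, 'bar'), (2, 38, 'baz'), (3, 50, 'baz')]
```

The number of comparisons is the product of the two table sizes, which is why this approach degrades quickly as the tables grow.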

### Initial Dataset

In [19]:
geo_loc_table = spark.createDataFrame([
    (1, 10, "foo"), 
    (11, 36, "bar"), 
    (37, 59, "baz"),
], ["ipstart", "ipend", "loc"])

geo_loc_table.toPandas()

Unnamed: 0,ipstart,ipend,loc
0,1,10,foo
1,11,36,bar
2,37,59,baz


In [20]:
records_table = spark.createDataFrame([
    (1, 11), 
    (2, 38), 
    (3, 50),
],["id", "inet"])

records_table.toPandas()

Unnamed: 0,id,inet
0,1,11
1,2,38
2,3,50


### Option #1: Join on the Range Condition

In [21]:
join_condition = [
    records_table['inet'] >= geo_loc_table['ipstart'],
    records_table['inet'] <= geo_loc_table['ipend'],
]

df = records_table.join(geo_loc_table, join_condition, "left")

df.toPandas()

Unnamed: 0,id,inet,ipstart,ipend,loc
0,1,11,11,36,bar
1,2,38,37,59,baz
2,3,50,37,59,baz


In [22]:
df.explain()

== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, LeftOuter, ((inet#90L >= ipstart#83L) && (inet#90L <= ipend#84L))
:- Scan ExistingRDD[id#89L,inet#90L]
+- BroadcastExchange IdentityBroadcastMode
   +- Scan ExistingRDD[ipstart#83L,ipend#84L,loc#85]


### Option #2: Broadcast Lookup with a UDF, then Equi-Join

In [23]:
from bisect import bisect_right
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

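# Collect the sorted range starts to the driver and broadcast them as a list.
# bisect_right needs an indexable sequence, so a lazy map object would not work.
geo_start_bd = spark.sparkContext.broadcast([
    row.ipstart
    for row in geo_loc_table.select("ipstart").orderBy("ipstart").collect()
])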

def find_le(x):
    'Find rightmost value less than or equal to x'
    i = bisect_right(geo_start_bd.value, x)
    if i:
        return geo_start_bd.value[i-1]
    return None

records_table_with_ipstart = records_table.withColumn(
    "ipstart", udf(find_le, LongType())("inet")
)

df = records_table_with_ipstart.join(geo_loc_table, ["ipstart"], "left")

df.toPandas()

Unnamed: 0,ipstart,id,inet,ipend,loc
0,37,2,38,59,baz
1,37,3,50,59,baz
2,11,1,11,36,bar


In [24]:
df.explain()

== Physical Plan ==
*(4) Project [ipstart#110L, id#89L, inet#90L, ipend#84L, loc#85]
+- SortMergeJoin [ipstart#110L], [ipstart#83L], LeftOuter
   :- *(2) Sort [ipstart#110L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(ipstart#110L, 200)
   :     +- *(1) Project [id#89L, inet#90L, pythonUDF0#119L AS ipstart#110L]
   :        +- BatchEvalPython [find_le(inet#90L)], [id#89L, inet#90L, pythonUDF0#119L]
   :           +- Scan ExistingRDD[id#89L,inet#90L]
   +- *(3) Sort [ipstart#83L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(ipstart#83L, 200)
         +- Scan ExistingRDD[ipstart#83L,ipend#84L,loc#85]
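The lookup in Option #2 can be checked outside Spark. The sketch below reuses the same `bisect_right` idea on plain Python lists with the values from `geo_loc_table`. It also adds an `ipend` check that is not part of Option #2: without it, a value that falls after the end of a range (for example 60 here) would still be assigned to the range starting at 37. The names `lookup_loc`, `starts`, and `by_start` are hypothetical.

```python
from bisect import bisect_right

geo_ranges = [(1, 10, "foo"), (11, 36, "bar"), (37, 59, "baz")]  # (ipstart, ipend, loc)
starts = sorted(r[0] for r in geo_ranges)
by_start = {r[0]: r for r in geo_ranges}


def find_le(sorted_values, x):
    """Return the rightmost value less than or equal to x, or None."""
    i = bisect_right(sorted_values, x)
    return sorted_values[i - 1] if i else None


def lookup_loc(inet):
    """Map inet to a location using binary search on the range starts."""
    start = find_le(starts, inet)
    if start is None:
        return None
    ipstart, ipend, loc = by_start[start]
    # Extra guard, not in Option #2: reject values past the end of the range.
    return loc if inet <= ipend else None


print([lookup_loc(x) for x in (11, 38, 50)])  # ['bar', 'baz', 'baz']
print(lookup_loc(0), lookup_loc(60))          # None None
```

Each lookup is a binary search over the sorted starts instead of a comparison against every range, and the join that follows becomes an equality join on `ipstart`. This assumes the ranges do not overlap.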
