In [1]:
import spark_env

spark = spark_env.create_spark_session('joins')

In [10]:
from pyspark.sql.functions import *
from pyspark.sql.types import *

- In Spark, a join is more than key matching: the engine picks one of several physical strategies based on data size, available memory, and join type.
- The most efficient strategy is the **Broadcast Hash Join**, which sends a copy of the smaller DataFrame to every executor. The large side is never shuffled, and each executor performs the join locally. It is ideal when one side of the join is small enough to fit in memory (the automatic threshold is 10 MB by default, configurable via spark.sql.autoBroadcastJoinThreshold). You can also request it explicitly by wrapping the small side in broadcast(df_small) when joining with a large DataFrame.
- When one side is too large to broadcast but still significantly smaller than the other, Spark may use a **Shuffle Hash Join**. Both datasets are shuffled by the join key, and a hash table is then built on the smaller side of each partition for lookups. It can be faster than a Sort-Merge Join because it skips the sort, provided memory is sufficient and the key distribution is not skewed. It is more sensitive to memory pressure, since each partition's hash table has to fit in memory.
- For large joins where both DataFrames are big and neither can be broadcast, Spark uses the **Sort-Merge Join**. Both sides are shuffled and sorted on the join key and then merged. Although this requires a full shuffle and sort, it is the most stable and scalable approach for large joins. It is the default choice for equi-joins when the data size exceeds the broadcast threshold and there is no major skew.
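
The two core algorithms behind these strategies can be sketched in plain Python. This is only an illustration of the idea, not Spark's implementation: the function names `hash_join` and `sort_merge_join` are hypothetical, rows are simple `(key, value)` tuples, and shuffling across executors is left out entirely.

```python
from collections import defaultdict


def hash_join(build_rows, probe_rows):
    """Inner join: hash the (small) build side, then stream the probe side past it."""
    table = defaultdict(list)
    for key, value in build_rows:
        table[key].append(value)
    return [
        (key, probe_value, build_value)
        for key, probe_value in probe_rows
        for build_value in table.get(key, [])
    ]


def sort_merge_join(left_rows, right_rows):
    """Inner join: sort both sides on the key, then advance two cursors in step."""
    left = sorted(left_rows, key=lambda row: row[0])
    right = sorted(right_rows, key=lambda row: row[0])
    out, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        left_key, right_key = left[i][0], right[j][0]
        if left_key < right_key:
            i += 1
        elif left_key > right_key:
            j += 1
        else:
            # Find the run of equal keys on the right, then pair it with
            # every left row that carries the same key.
            j_end = j
            while j_end < len(right) and right[j_end][0] == left_key:
                j_end += 1
            while i < len(left) and left[i][0] == left_key:
                for _, right_value in right[j:j_end]:
                    out.append((left_key, left[i][1], right_value))
                i += 1
            j = j_end
    return out


names = [(1, "Alice"), (2, "Bob"), (3, "Charlie"), (4, "David"), (5, "Eva")]
salaries = [(1, 50000), (2, 60000), (3, 70000), (6, 80000)]

via_hash = hash_join(build_rows=salaries, probe_rows=names)
via_merge = sort_merge_join(names, salaries)
print(via_hash)
print(via_merge == via_hash)
```

The hash variant needs the whole build side in memory but touches each probe row once; the merge variant pays for two sorts but only ever looks at the current run of equal keys.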

In [3]:
# -1 disables automatic broadcast joins; False turns off adaptive query execution (AQE).
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', -1)
spark.conf.set('spark.sql.adaptive.enabled', False)

In [7]:
data1 = [
    (1,"Alice"),
    (2,"Bob"),
    (3,"Charlie"),
    (4,"David"),
    (5,"Eva")
]
df1 = spark.createDataFrame(data1, ["id","name"])

data2 = [
    (1,50000),
    (2,60000),
    (3,70000),
    (6,80000)
]
df2 = spark.createDataFrame(data2, ["id","salary"])

In [8]:
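# Left join: keep every row of df1; rows without a match get NULLs for df2's columns.
# Joining on an expression keeps both id columns in the result.
df_left_join = df1.join(df2, df1['id'] == df2['id'], 'left')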

In [9]:
df_left_join.show()

+---+-------+----+------+
| id|   name|  id|salary|
+---+-------+----+------+
|  5|    Eva|NULL|  NULL|
|  1|  Alice|   1| 50000|
|  3|Charlie|   3| 70000|
|  2|    Bob|   2| 60000|
|  4|  David|NULL|  NULL|
+---+-------+----+------+



## Driver Memory Management

- The Driver in Spark is the master process that coordinates all tasks.
- It holds metadata, task scheduling information, DAGs and more.
- Driver memory is the memory allocated to the Driver process when the Spark job runs.
- If the Driver runs out of memory, the job fails with an OutOfMemoryError.
- Driver memory is divided into two types:
1. **JVM Heap Memory**: This is the main memory allocated to the Java Virtual Machine, where the driver's core data structures live (e.g., RDD and DataFrame metadata). It is used for task scheduling, query planning and optimization, cached metadata, broadcast variables, etc. JVM heap memory is where Spark runs its logic.
2. **Overhead Memory**: This is extra memory reserved outside the JVM heap. It is used for native memory (such as PySpark or Pandas UDFs), thread stacks, internal buffers, and memory managed by the OS. Overhead memory prevents out-of-memory errors during native or I/O-heavy operations. If you are using PySpark or UDFs, increasing overhead memory is often necessary. By default it is max(10% of JVM heap memory, 384 MB).
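
The overhead rule quoted above is easy to check with a few lines of arithmetic. The helper name `memory_overhead_mb` is hypothetical; the 10% factor and the 384 MB floor are the values stated in the text.

```python
def memory_overhead_mb(heap_mb: int, factor: float = 0.10, minimum_mb: int = 384) -> int:
    """Overhead = max(factor * heap, minimum), in whole megabytes."""
    return max(int(heap_mb * factor), minimum_mb)


for heap_mb in (1024, 2048, 8192):
    overhead = memory_overhead_mb(heap_mb)
    print(f"heap={heap_mb} MB  overhead={overhead} MB  total={heap_mb + overhead} MB")
```

For small heaps the 384 MB floor dominates, so the total memory requested for the process is noticeably larger than the heap size alone.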

#### Driver Out of Memory
This error occurs when the results that executors send back to the driver exceed the driver's memory. We can mitigate it by avoiding actions that return too much data to the driver, such as df.collect() on a large DataFrame.

## Executor Memory Management

- Executor memory is broadly divided into four categories:
1. JVM Heap Memory: This memory is further broken down into three regions  
   a. **Reserved Memory** [300 MB]: A fixed 300 MB set aside for Spark's internal use  
   b. **User Memory** [0.4*(Total memory - Reserved Memory)]: Stores user-defined data structures and objects created by user code, such as those used inside user-defined functions.  
   c. **Spark Memory Pool** [0.6*(Total memory - Reserved Memory)]:
   - This memory holds cached data and the working data of transformations.
   - 50% is used for caching and storing (long-term memory). This is called **Storage Memory**.
   - 50% is used for transformations such as shuffles, joins, sorts and aggregations (short-term memory). This is called **Execution Memory**.
   - However, this boundary is not fixed: each side can borrow unused memory from the other. This is called allocation and borrowing.
   - Execution memory can evict cached blocks from storage memory (least recently used first), but storage memory cannot evict execution memory.  
2. Off-Heap Memory: Managed by the user; the default is zero
3. Overhead Memory: This is extra memory outside the JVM heap -> max(10% of executor memory, 384 MB)
4. PySpark Memory: Rarely used; not set by default
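
As a worked example of the heap breakdown above, the sketch below applies the stated fractions to a given heap size. The function name and the 4300 MB heap are illustrative assumptions; the 300 MB reservation, the 0.6 pool fraction and the 50/50 split come from the list above.

```python
RESERVED_MB = 300  # fixed reservation described in the list above


def executor_heap_breakdown(heap_mb, memory_fraction=0.6, storage_fraction=0.5):
    """Split an executor heap into the regions described in the text (values in MB)."""
    usable = heap_mb - RESERVED_MB
    spark_pool = usable * memory_fraction
    storage = spark_pool * storage_fraction
    return {
        "reserved": RESERVED_MB,
        "user": round(usable - spark_pool),
        "spark_pool": round(spark_pool),
        "storage": round(storage),
        "execution": round(spark_pool - storage),
    }


breakdown = executor_heap_breakdown(4300)
for region, size_mb in breakdown.items():
    print(f"{region:>10}: {size_mb} MB")
```

With a 4300 MB heap, 4000 MB remain after the reservation: 1600 MB of user memory and a 2400 MB Spark pool, which starts out split evenly between storage and execution.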


#### Executor Out of Memory
This is usually caused by one of the following:
- **Large data per task (partition too big):** A single task processes more data than the executor can hold.
- **Skewed data in joins or aggregations:** One key has too much data → some executors do all the work and crash.
- **Improper caching/persisting:** Caching a large DataFrame without enough memory (especially with MEMORY_ONLY) leads to eviction or OOM.
- **Use of wide transformations:** Operations like groupByKey, join, sort cause shuffle and large intermediate data in memory.
- **Calling collect() on large DataFrames:** Pulls every row into driver memory, so the driver (rather than an executor) crashes if the result does not fit.
- **Heavy PySpark UDFs or Pandas UDFs:** Native code or Python logic consumes off-heap memory → hits memoryOverhead limit.
- **Too many tasks per executor:** Multiple tasks run concurrently, each using memory → total usage exceeds limit.
- **Improper configuration:** --executor-memory or spark.executor.memoryOverhead is set too low.
- **GC overhead without an error:** The JVM spends too much time in garbage collection because of memory pressure. This is a near-OOM symptom that slows tasks down even when no error is raised.


#### Salting

Salting is a technique used in Spark to handle data skew, which occurs when one or a few keys in a join or aggregation operation have significantly more data than others. This imbalance causes Spark to assign a disproportionately large amount of data to a single task or executor, leading to performance bottlenecks or even out-of-memory errors. Salting mitigates this by artificially distributing skewed keys across multiple partitions. It works by appending a random "salt" value (like a number) to the skewed key, effectively transforming one heavy key into multiple lighter keys (e.g., "India" becomes "India_1", "India_2", etc.). The other dataset (in the case of a join) is also modified to match this salted structure. After performing the join or aggregation, the results can optionally be de-salted or recombined. This approach helps achieve better parallelism and load balancing during execution.
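
The effect of salting can be simulated without a cluster. The sketch below is a plain-Python illustration under stated assumptions: `zlib.crc32` stands in for Spark's hash partitioner, the partition count and salt range are arbitrary, and the data (900 rows for one key, 50 each for two others) is made up for the demonstration.

```python
import random
import zlib
from collections import Counter

NUM_PARTITIONS = 8
NUM_SALTS = 8  # hypothetical salt range; in practice tuned to the degree of skew


def partition_of(key, num_partitions=NUM_PARTITIONS):
    """Deterministic stand-in for a hash partitioner."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def partition_loads(keys):
    """Number of rows that land in each partition."""
    counts = Counter(partition_of(key) for key in keys)
    return [counts.get(p, 0) for p in range(NUM_PARTITIONS)]


# Skewed input: one key carries 90% of the rows.
keys = ["India"] * 900 + ["Fiji"] * 50 + ["Malta"] * 50
unsalted_loads = partition_loads(keys)

# Salt the skewed side: append a random suffix to every key.
rng = random.Random(42)
salted_keys = [f"{key}_{rng.randrange(NUM_SALTS)}" for key in keys]
salted_loads = partition_loads(salted_keys)

# The other side of the join is replicated once per salt value so that
# every salted key still finds its match.
country_codes = {"India": "IN", "Fiji": "FJ", "Malta": "MT"}
salted_codes = {
    f"{country}_{salt}": code
    for country, code in country_codes.items()
    for salt in range(NUM_SALTS)
}

# Join on the salted key, then strip the salt again ("de-salting").
joined = [(key.rsplit("_", 1)[0], salted_codes[key]) for key in salted_keys]

print("max rows per partition, unsalted:", max(unsalted_loads))
print("max rows per partition, salted:  ", max(salted_loads))
```

The cost of the technique is visible in `salted_codes`: the non-skewed side grows by a factor of `NUM_SALTS`, which is why salting is usually applied only to the keys that are actually skewed.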

## Caching

In [13]:
data1 = [
    (1,"Alice"),
    (2,"Bob"),
    (3,"Charlie"),
    (4,"David"),
    (5,"Eva")
]
df1 = spark.createDataFrame(data1, ["id","name"])
df3 = spark.createDataFrame(data1, ["id","name"])

In [14]:
df1.show()

+---+-------+
| id|   name|
+---+-------+
|  1|  Alice|
|  2|    Bob|
|  3|Charlie|
|  4|  David|
|  5|    Eva|
+---+-------+



In [15]:
df1 = df1.withColumn('flag',lit('Yes'))

In [16]:
df1.show()

+---+-------+----+
| id|   name|flag|
+---+-------+----+
|  1|  Alice| Yes|
|  2|    Bob| Yes|
|  3|Charlie| Yes|
|  4|  David| Yes|
|  5|    Eva| Yes|
+---+-------+----+



#### Without Caching

Here we create df2 from df1 without caching. DataFrames are lazy: df1 is not kept in memory as a materialized result, only as a plan (its DAG, or lineage). When an action runs on df2, Spark re-executes df1's lineage from the source, which is why the physical plan for df2 still contains the Project step that adds the flag column. Recomputing df1 for every DataFrame derived from it is inefficient when df1 is expensive to build.

In [17]:
df2 = df1.filter(col('id') == 1)

In [18]:
df2.show()

+---+-----+----+
| id| name|flag|
+---+-----+----+
|  1|Alice| Yes|
+---+-----+----+



In [19]:
df2.explain()

== Physical Plan ==
*(1) Project [id#56L, name#57, Yes AS flag#73]
+- *(1) Filter (isnotnull(id#56L) AND (id#56L = 1))
   +- *(1) Scan ExistingRDD[id#56L,name#57]




#### With Caching

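Here we cache df3, which holds the same data as the original df1 (before the flag column was added), and then create df4 with the same filter that produced df2. cache() is itself lazy, so the data is stored only when the first action runs. Comparing df4.explain() with df2.explain() shows the difference: the plan for df4 reads from an InMemoryTableScan over the cached relation.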

In [20]:
df3.cache()

DataFrame[id: bigint, name: string]

In [21]:
df4 = df3.filter(col('id') == 1)

In [22]:
df4.show()

+---+-----+
| id| name|
+---+-----+
|  1|Alice|
+---+-----+



In [23]:
df4.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Filter (isnotnull(id#60L) AND (id#60L = 1))
   +- InMemoryTableScan [id#60L, name#61], [isnotnull(id#60L), (id#60L = 1)]
         +- InMemoryRelation [id#60L, name#61], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Scan ExistingRDD[id#60L,name#61]
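
The difference between recomputing from lineage and reading from a cache can be mimicked in plain Python. `LazyFrame` below is a hypothetical toy class, not a Spark API: it stores a recipe instead of rows and counts how often that recipe is executed.

```python
class LazyFrame:
    """Toy stand-in for a DataFrame: holds a recipe (lineage), not materialized rows."""

    def __init__(self, compute):
        self._compute = compute
        self._cache_requested = False
        self._cached_rows = None
        self.times_computed = 0

    def cache(self):
        # Lazy, like DataFrame.cache(): only marks the frame for caching.
        self._cache_requested = True
        return self

    def _rows(self):
        if self._cached_rows is not None:
            return self._cached_rows
        self.times_computed += 1
        rows = self._compute()
        if self._cache_requested:
            self._cached_rows = rows
        return rows

    def filter(self, predicate):
        # The child frame's recipe refers back to this frame's rows.
        return LazyFrame(lambda: [row for row in self._rows() if predicate(row)])

    def collect(self):
        return list(self._rows())


source = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]

uncached = LazyFrame(lambda: list(source))
uncached.filter(lambda row: row[0] == 1).collect()
uncached.filter(lambda row: row[0] == 2).collect()

cached = LazyFrame(lambda: list(source)).cache()
first = cached.filter(lambda row: row[0] == 1).collect()
second = cached.filter(lambda row: row[0] == 2).collect()

print("uncached parent computed:", uncached.times_computed, "times")
print("cached parent computed:  ", cached.times_computed, "time(s)")
```

Each action on a child of the uncached frame re-runs the parent's recipe, while the cached parent is computed once, on the first action, and reused afterwards.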




## Storage Levels for Cache

1. **MEMORY_AND_DISK**
- Tries to store the data in memory first.
- If memory is not enough, spill the rest to the disk.
- df.cache() = df.persist(StorageLevel.MEMORY_AND_DISK)

2. **MEMORY_ONLY**
- Data is stored in RAM as deserialized Java objects.
- If not enough memory, recompute the partitions when needed.

3. **DISK_ONLY**
- Stores data on the disk
- Slowest option.

4. **OFF_HEAP**
- Uses off-heap memory(outside JVM heap)
- Must be enabled with spark.memory.offHeap.enabled=true

In [24]:
df3.unpersist()

DataFrame[id: bigint, name: string]

### Partitioning

Partitioning in Spark refers to how data is divided into smaller chunks (partitions) for parallel processing across the cluster. This can be done in memory using functions like .repartition() or .coalesce() to control the number and distribution of tasks across executors. It also applies to on-disk storage — when writing data with .write.partitionBy(column), Spark creates folder structures based on that column (e.g., /year=2023/). Proper partitioning improves parallelism and query performance by reducing data shuffling and task imbalance.

### Pruning
Pruning, specifically column pruning, is a logical optimization where Spark reads only the columns requested in a query instead of loading the entire dataset. This is especially effective with columnar storage formats like Parquet and ORC. For example, when you run df.select("name"), Spark will only load the name column from disk, skipping the rest. This reduces I/O and memory usage, speeding up query execution.
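
A small plain-Python comparison shows why column pruning pairs well with columnar formats. The data and helper names are illustrative assumptions; "values read" is just a count of fields touched, used here as a rough proxy for I/O.

```python
rows = [
    ("Alice", "HR", 1000),
    ("Bob", "IT", 2000),
    ("Charlie", "HR", 1500),
]
column_names = ["name", "department", "salary"]

# Columnar layout: one contiguous list per column, as in Parquet or ORC.
columnar = {name: [row[i] for row in rows] for i, name in enumerate(column_names)}


def select_from_rows(rows, column_names, wanted):
    """Row layout: every full row is read, then the unwanted fields are discarded."""
    index = column_names.index(wanted)
    values_read = sum(len(row) for row in rows)
    return [row[index] for row in rows], values_read


def select_from_columns(columnar, wanted):
    """Columnar layout: only the requested column is read."""
    values = columnar[wanted]
    return list(values), len(values)


names_from_rows, read_from_rows = select_from_rows(rows, column_names, "name")
names_from_columns, read_from_columns = select_from_columns(columnar, "name")

print(names_from_columns)
print("values read, row layout:     ", read_from_rows)
print("values read, columnar layout:", read_from_columns)
```

Both layouts return the same names, but the row layout touches all nine fields while the columnar layout touches only the three values in the requested column.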

### Partition Pruning
Partition pruning is a physical read-time optimization where Spark avoids scanning irrelevant data partitions on disk based on query filters. If a dataset is partitioned by country and your query includes filter(country = 'US'), Spark will only scan the /country=US/ folder rather than reading all partitions. This significantly improves performance. Partition pruning can be static (filter known at compile time) or dynamic (filter resolved at runtime, often in joins).
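
Static partition pruning can be imitated with ordinary directories. The sketch below writes small CSV files into `department=<value>/` folders under a temporary directory and then reads only the folder that matches the filter. The helper names and the CSV format are assumptions made for the illustration; Spark's Parquet reader works differently internally, but the directory-skipping idea is the same.

```python
import tempfile
from pathlib import Path

rows = [
    ("Alice", "HR", 1000),
    ("Bob", "IT", 2000),
    ("Charlie", "HR", 1500),
    ("David", "Finance", 2500),
    ("Eve", "IT", 1800),
    ("Frank", "Finance", 2200),
]


def write_partitioned(rows, root):
    """Write one folder per department; the partition value lives in the path."""
    for name, department, salary in rows:
        part_dir = root / f"department={department}"
        part_dir.mkdir(parents=True, exist_ok=True)
        with open(part_dir / "part-0.csv", "a", encoding="utf-8") as handle:
            handle.write(f"{name},{salary}\n")


def read_with_pruning(root, department):
    """Open only the folders whose partition value matches the filter."""
    result, files_read = [], 0
    for part_dir in sorted(root.iterdir()):
        column, _, value = part_dir.name.partition("=")
        if column != "department" or value != department:
            continue  # pruned: this folder is never opened
        for data_file in sorted(part_dir.iterdir()):
            files_read += 1
            for line in data_file.read_text(encoding="utf-8").splitlines():
                name, salary = line.split(",")
                result.append((name, int(salary), value))
    return result, files_read


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    write_partitioned(rows, root)
    total_files = sum(1 for _ in root.rglob("*.csv"))
    hr_rows, files_read = read_with_pruning(root, "HR")

print(hr_rows)
print(f"files read: {files_read} of {total_files}")
```

The partition column is reconstructed from the folder name rather than stored in the data files, which is also why it moves to the last position in the result.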

In [5]:
sample_data = [
    ("Alice",'HR',1000),
    ("Bob",'IT',2000),
    ("Charlie",'HR',1500),
    ("David",'Finance',2500),
    ("Eve",'IT',1800),
    ("Frank",'Finance',2200)
]

columns = ['name','department','salary']

df = spark.createDataFrame(sample_data, columns)

output_path = './Output/'
output_path_new = './OutputNew/'

In [4]:
df.write\
    .mode('overwrite')\
    .partitionBy('department')\
    .parquet(output_path)

In [6]:
df.write\
    .mode('overwrite')\
    .parquet(output_path_new)

In [14]:
df_from_part = spark.read.format('parquet')\
                            .load(output_path)\
                            .filter(col('department') == 'HR')                            

In [15]:
df_from_part.show()

+-------+------+----------+
|   name|salary|department|
+-------+------+----------+
|Charlie|  1500|        HR|
|  Alice|  1000|        HR|
+-------+------+----------+



In [16]:
df_without_part = spark.read.format('parquet')\
                            .load(output_path_new)\
                            .filter(col('department') == 'HR')

In [17]:
df_without_part.show()

+-------+----------+------+
|   name|department|salary|
+-------+----------+------+
|Charlie|        HR|  1500|
|  Alice|        HR|  1000|
+-------+----------+------+



### Impact of Partition Pruning

When we filter the dataset that was written without partitioning, the scan reports:
- Scan parquet  
**number of files read: 7**  
**scan time total (min, med, max ) -> 127 ms (15 ms, 16 ms, 26 ms )**  
metadata time: 0 ms  
**size of files read: 6.2 KiB**  
number of output rows: 2

When we filter the partitioned dataset, the scan reports:
- Scan parquet  
**number of files read: 2**  
**scan time total (min, med, max ) -> 31 ms (15 ms, 16 ms, 16 ms )**  
dynamic partition pruning time: 0 ms  
metadata time: 13 ms  
**size of files read: 1476.0 B**  
number of output rows: 2  
number of partitions read: 1  

### Dynamic Partition Pruning

Dynamic Partition Pruning (DPP) is a performance optimization in Spark SQL that reduces the amount of data read during a join by dynamically filtering partitions at runtime instead of reading all partitions of a table.

🔹 How It Works:
When you run a query that joins a large partitioned table with another table, Spark doesn't always know in advance (at compile time) which partitions it needs. With Dynamic Partition Pruning, Spark waits until the broadcast or shuffle of the smaller table completes, extracts the relevant partition values, and then prunes the partitions of the larger table at execution time.

🔹 Example:

SELECT *
FROM large_sales s
JOIN region_filter r
ON s.region_id = r.region_id

If large_sales is partitioned by region_id, DPP ensures Spark only scans the matching partitions based on values from region_filter — instead of scanning all region_id partitions.

🔹 Why It Matters:
- Improves performance (reduces I/O and scan time)
- Especially useful when filter values are only known at runtime (e.g., in subqueries or joins)
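
The runtime behavior described above can be sketched in plain Python. The table contents and the function name `join_with_dpp` are hypothetical; the fact table is modeled as a dictionary from partition value (`region_id`) to the rows stored in that partition.

```python
# Large fact table, partitioned by region_id: partition value -> rows (region_id, amount).
large_sales = {
    1: [(1, 100), (1, 250)],
    2: [(2, 75)],
    3: [(3, 300), (3, 20)],
    4: [(4, 60)],
}

# Small table whose contents are only known once the query runs.
region_filter = [(2, "EMEA"), (4, "APAC")]


def join_with_dpp(partitions, small_rows):
    """Join a partitioned table with a small table, scanning only matching partitions."""
    # Step 1: evaluate the small side first and collect its join keys.
    names_by_region = {region_id: name for region_id, name in small_rows}
    # Step 2: use those keys as a runtime filter on the partition list.
    scanned = [value for value in sorted(partitions) if value in names_by_region]
    # Step 3: join only the rows from the partitions that survived pruning.
    joined = [
        (region_id, amount, names_by_region[region_id])
        for value in scanned
        for region_id, amount in partitions[value]
    ]
    return joined, scanned


joined, scanned = join_with_dpp(large_sales, region_filter)
print(joined)
print(f"partitions scanned: {scanned} of {sorted(large_sales)}")
```

Partitions 1 and 3 are never read, even though the query text contains no literal filter on region_id.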