## APACHE SPARK / PYSPARK 

### What is Spark?

Spark is an in-memory, distributed, parallel processing engine with built-in fault tolerance.

### Architecture of Spark

<img src="https://aws.github.io/aws-emr-best-practices/assets/images/spark-bp-4-0ac9c7cc0bb2cff70ac01efa69455604.png" width="600"  height="400">

**Driver:** 
1. The driver is the heart of a Spark application.
2. The driver maintains information about the worker nodes and their executors.
3. The driver analyses the job, then distributes and schedules its tasks on the worker nodes.
4. The driver is the process that runs the main() function of the Spark application and creates the SparkSession.

-----

**Executor:**
1. An executor is a process launched on a worker node for a Spark application; it runs tasks and keeps data in memory or on disk.
2. It reports task execution status back to the driver.

------

**Cluster manager:**
1. The cluster manager (e.g. YARN, Kubernetes, Standalone) allocates cluster resources to the application: the number of executors and the cores and memory for each.

-------

**Partitions:**
1. Partitions are the logical divisions of a dataset.
2. Each partition is processed by a single task running on one core of an executor.
3. Partitions are key to achieving parallelism and optimizing performance.

----

**Transformations:**
1. A transformation takes an RDD/DataFrame as input, applies logic to it, and returns a new RDD/DataFrame as output.
2. Transformations are lazily evaluated: they are not executed immediately but wait until an action is called.
3. Lazy evaluation lets Spark see the whole chain of transformations and optimize the execution plan before running it.

Example: **Select, Where, Group by, Order By, Limit**

There are two types of transformations in the Spark ecosystem:
1. **Narrow Transformations:** each input partition contributes to at most one output partition, so no data has to move between partitions.
2. **Wide Transformations:** data from multiple input partitions is needed to produce one output partition, which requires shuffling data across the cluster.


Narrow Transformation example: Select, Where, map() \
Wide Transformation example: Join, Group By, Order By, groupByKey(), reduceByKey(), sortByKey(), join(), cogroup()
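A plain-Python model of the difference (not Spark code; the partition layout and the hash routing are simplifications chosen for illustration): a narrow step rewrites each partition independently, while a wide step such as a group-by must first route every record to the partition that owns its key.

```python
from collections import defaultdict

# A "dataset" is modelled as a list of partitions; each partition is a list of
# (key, value) records. This is a toy model, not how Spark stores data.

def narrow_map(partitions, fn):
    """Narrow: output partition i is computed from input partition i only."""
    return [[fn(record) for record in part] for part in partitions]

def shuffle_by_key(partitions, num_out):
    """Wide: any input partition may send records to any output partition.
    Records are routed by hash(key) % num_out, so equal keys end up together."""
    out = [[] for _ in range(num_out)]
    for part in partitions:
        for key, value in part:
            out[hash(key) % num_out].append((key, value))
    return out

def group_sum(partitions, num_out):
    """groupBy(key).sum(value): shuffle first, then aggregate each partition locally."""
    result = {}
    for part in shuffle_by_key(partitions, num_out):
        sums = defaultdict(int)
        for key, value in part:
            sums[key] += value
        # Each key lives in exactly one shuffled partition, so nothing is overwritten.
        result.update(sums)
    return result

data = [[("M", 24), ("F", 53)], [("M", 23), ("M", 24)], [("F", 33)]]
doubled = narrow_map(data, lambda kv: (kv[0], kv[1] * 2))
print(len(doubled))                        # 3 partitions in, 3 partitions out
print(sorted(group_sum(data, 2).items()))  # [('F', 86), ('M', 71)]
```

The map never looks outside its own partition, whereas group_sum cannot produce any result until shuffle_by_key has read every input partition; that full read-and-redistribute is the shuffle.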

-----


**Actions:**
1. An action triggers execution of the plan built from the preceding transformations, and either returns data to the driver or writes it to an output system.

Example: collect(), count(), head(), tail(), show(), take(), save()
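Lazy evaluation can be illustrated with a toy class (a hypothetical LazyFrame, not part of PySpark): its transformation methods only record a step, and nothing runs until an action such as collect() or count() is called.

```python
class LazyFrame:
    """Toy model of lazy evaluation: transformations record steps,
    actions replay the recorded steps over the data."""

    def __init__(self, rows, steps=()):
        self._rows = rows
        self._steps = tuple(steps)

    # --- transformations: return a new LazyFrame, execute nothing ---
    def where(self, pred):
        return LazyFrame(self._rows, self._steps + (("where", pred),))

    def select(self, fn):
        return LazyFrame(self._rows, self._steps + (("select", fn),))

    # --- actions: actually run the recorded plan ---
    def collect(self):
        out = []
        for row in self._rows:
            keep = True
            for kind, fn in self._steps:
                if kind == "where" and not fn(row):
                    keep = False
                    break
                if kind == "select":
                    row = fn(row)
            if keep:
                out.append(row)
        return out

    def count(self):
        return len(self.collect())


calls = []

def is_adult(row):
    calls.append(row)  # record every time the filter actually runs
    return row["age"] >= 18

users = [{"id": 1, "age": 24}, {"id": 2, "age": 12}, {"id": 3, "age": 53}]
plan = LazyFrame(users).where(is_adult).select(lambda r: r["id"])
print(len(calls))      # 0 -> building the plan ran nothing
print(plan.collect())  # [1, 3]
print(len(calls))      # 3 -> the filter ran once per row during the action
```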

#### RDD VS DATAFRAME VS DATASET 

<img src="https://media.licdn.com/dms/image/v2/C4D12AQHqtDQj79malQ/article-cover_image-shrink_600_2000/article-cover_image-shrink_600_2000/0/1520122178712?e=2147483647&v=beta&t=gluvrQfDOKCjDbCHgBXi_ZMUY8H6oAJ1m0ghw-lNbY4" width="500"  height="500">

**RDD:**
1. The abbreviation of RDD is Resilient Distributed Datasets. It is an immutable collection of data objects that is distributed across the cluster and processed in parallel with fault-tolerant capabilities. 
2. It is a low-level API.
3. It does not provide schema enforcement or optimization techniques.
4. It is best suited for unstructured data and low-level transformations. 

**DataFrames:**
1. DataFrames are distributed collections of data organized into rows and columns, similar to a table in a relational database.
2.	They are conceptually equivalent to relational tables.
3.	DataFrames provide a high-level API for data manipulation.
4.	They are not type-safe, meaning if you try to access a column that doesn’t exist, the error will not be caught at compile time—it will be detected only at runtime.
5.	DataFrames use the Catalyst Optimizer for query optimization and code generation.
6.	Schema enforcement is supported, but only at runtime.
7.	They support multiple programming languages, including Python, Scala, Java, and R.
8.	Ideal for working with structured and semi-structured data.

**Dataset:**
A Dataset is a distributed collection of data that combines the benefits of strong typing, compile-time type safety, and object-oriented programming. It can be thought of as a typed version of a DataFrame, where each row is represented as an object of a specific type, defined using a case class in Scala or a Java bean in Java. The Dataset API is available only in Scala and Java; in PySpark you work with DataFrames.

#### How Spark works with partitions

<img src="https://images.squarespace-cdn.com/content/v1/58a3db0903596e0fdbb27391/1502484724596-T4IEBESUF806VN5CZOEZ/image-asset.png" width="500"  height="500">

1. When Spark reads data, it automatically divides it into partitions, based on the data source (e.g. file blocks in HDFS) or on user-defined configuration.
2. Each partition is a logical chunk containing a subset of the overall data.
3. The number of partitions can be controlled by the user; it is recommended to use a multiple of the number of cores in the cluster.
4. Each partition is processed by a task, which is the smallest unit of work in a Spark job.
5. Spark assigns one task per partition, which lets it process partitions in parallel.
6. When Spark applies transformations (like map, filter, join) to the data, it operates on partitions.
7. Narrow transformations (like map, filter) can be applied within a partition without needing data from other partitions.
8. Wide transformations (like groupByKey, join) may require data from multiple partitions to be shuffled and redistributed, which can become a performance bottleneck.
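The advice to use a multiple of the core count follows from how tasks are scheduled: with one task per partition and one task per core at a time, partitions are processed in "waves". A simplified calculation, assuming equally sized tasks and one core per task (real tasks vary in duration):

```python
import math

def task_waves(num_partitions, total_cores):
    """Return (number of waves, cores busy in the last wave), assuming one
    core per task and tasks that all take the same time."""
    waves = math.ceil(num_partitions / total_cores)
    busy_in_last_wave = num_partitions - (waves - 1) * total_cores
    return waves, busy_in_last_wave

total_cores = 4 * 4  # e.g. 4 executors x 4 cores each
for parts in (16, 17, 32):
    waves, busy = task_waves(parts, total_cores)
    print(f"{parts} partitions -> {waves} wave(s), last wave uses {busy}/{total_cores} cores")
```

With 17 partitions the second wave keeps only 1 of the 16 cores busy while the other 15 sit idle, which is the waste the "multiple of the cores" rule avoids.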

### SPARK EXECUTION PLAN

<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*e4A-H9kJwnmbr9ivkBSazQ.png" width="600" height="500">

What Happens When You Submit a Spark Job (Spark 3 Flow)

1. 🧑‍💻 User submits a job \
The user runs spark-submit or executes a notebook cell \
The code contains transformations (map, filter) and actions (collect, count)


2. 🧠 Unresolved Logical Plan is created \
Spark parses your code and builds an "unresolved logical plan", in which table and column names are not yet resolved \
This is a blueprint of what you want to do, not yet of how to do it

3. 🧠 Analyzed Logical Plan is created \
The analyzer checks the unresolved logical plan against the metadata catalog, resolving and validating tables, columns, schema and data types; if everything checks out, the result is the "analyzed logical plan"

4. 🧹 Optimized Logical Plan \
Spark uses the Catalyst Optimizer to improve performance: \
Reorder operations\
Push down filters\
Prune projections\
Simplify expressions\
The result: an Optimized Logical Plan

**Lineage:** the plans can be inspected with explain() on a DataFrame, and an RDD's lineage with toDebugString() \
   **Predicate pushdown:** Moving the filter conditions [predicates] closer to the data source, i.e. applying the filtering logic as early as possible during data retrieval, rather than loading the entire dataset into memory and then filtering it. \
   
   **Projection pruning:** Removing unnecessary columns while processing the query.

5. ⚙️ Physical Plan is created \
Spark translates the optimized logical plan into a physical plan\
Now it decides how to execute, considering:\
data partitioning\
data shuffling\
task distribution across the nodes\
sort-merge join vs broadcast join

6. ✅ Best Physical Plan is selected\
Candidate physical plans are evaluated against a cost model, which estimates the cost of each plan\
Spark compares the execution strategies and chooses the one with the lowest estimated cost\

7. 📦 DAG (Directed Acyclic Graph) is built\
The physical plan is converted into a DAG of stages and tasks (job → stages → tasks)\
Each stage is a group of tasks that can run together without a shuffle; stage boundaries fall at shuffles\

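8. 🎯 Job is submitted\
The driver's DAGScheduler submits each stage as a set of tasks to the TaskScheduler; the DAG itself stays in the driver\

9. 🧩 Cluster Manager allocates Executors\
The driver requests executors from the Cluster Manager (like YARN/K8s/Standalone), at application start or on demand with dynamic allocation\
Executors are launched on worker nodes\
Each Executor gets a portion of your data to process\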

10. 🏃‍♂️ Tasks are scheduled and run\
Spark breaks down the stages into Tasks (smallest units of work)\
Each task runs on an executor and processes a partition of data\

11. 📤 Results are collected or saved\
If action is collect() → data is returned to the driver\
If write() → data is saved to S3, HDFS, BigQuery, etc.
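The effect of predicate pushdown and projection pruning can be shown with a plain-Python model that counts how many values are read. This is only an illustration under simplified assumptions (row-by-row reads, one unit of cost per value); in Spark the pushdown happens inside the data-source reader, for example by skipping Parquet row groups.

```python
# Toy table; every value read from "storage" costs one unit.
TABLE = [
    {"id": 1, "age": 24, "gender": "M", "occupation": "technician", "zipcode": 85711},
    {"id": 2, "age": 53, "gender": "F", "occupation": "other", "zipcode": 94043},
    {"id": 3, "age": 23, "gender": "M", "occupation": "writer", "zipcode": 32067},
]

def naive_plan(table, pred, columns):
    """Load every column of every row, then filter, then project."""
    loaded = [dict(row) for row in table]
    values_read = sum(len(row) for row in loaded)
    kept = [row for row in loaded if pred(row)]
    return [{c: row[c] for c in columns} for row in kept], values_read

def optimized_plan(table, pred, pred_columns, columns):
    """Pushdown: read only the filter columns first and skip non-matching rows.
    Pruning: for matching rows, read only the other columns the query selects."""
    out, values_read = [], 0
    for row in table:
        probe = {c: row[c] for c in pred_columns}
        values_read += len(probe)
        if not pred(probe):
            continue
        rest = {c: row[c] for c in columns if c not in probe}
        values_read += len(rest)
        merged = {**probe, **rest}
        out.append({c: merged[c] for c in columns})
    return out, values_read

def is_over_30(row):
    return row["age"] > 30

# SELECT id FROM users WHERE age > 30
print(naive_plan(TABLE, is_over_30, ["id"]))               # ([{'id': 2}], 15)
print(optimized_plan(TABLE, is_over_30, ["age"], ["id"]))  # ([{'id': 2}], 4)
```

Both plans return the same answer; the optimized one reads 4 values instead of 15, which is the whole point of letting Catalyst move filters and column selection down to the scan.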

### DAG (Directed Acyclic Graph)

<img src="https://techvidvan.com/tutorials/wp-content/uploads/2019/11/DAG-Visualisation-01.jpg" width="600" height="600">

**DAG:** 
1. The DAG represents the chain of operations Spark will execute; it is built from the physical plan of the transformations leading up to an action.
2. The DAGScheduler converts each action into a job, splits the job into stages at shuffle boundaries, and splits each stage into tasks that can run in parallel across the cluster.
3. **Stage:** A stage is a group of tasks that can be executed in parallel without a shuffle. There are two types of stages: ShuffleMapStage (writes shuffle output consumed by a later stage) and ResultStage (the final stage, which computes the action's result).
4. **Tasks:** A task is a single unit of work that runs on a single partition.
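How a job is cut into stages can be sketched for a simple linear chain of operations: every wide (shuffle) operation closes the current stage and starts a new one, and the last stage is the ResultStage. The operation names and the WIDE set below are illustrative assumptions; Spark itself works on the RDD dependency graph, which can branch (for example in joins).

```python
# Operations treated as shuffles in this toy model.
WIDE = {"groupBy", "join", "orderBy", "repartition", "distinct"}

def split_into_stages(ops):
    """Split a linear chain of operation names into stages at shuffle boundaries."""
    stages, current = [], []
    for op in ops:
        if op in WIDE and current:
            stages.append(current)  # the shuffle ends the previous stage
            current = []
        current.append(op)
    stages.append(current)
    return stages

def label_stages(stages):
    """All stages before the last write shuffle output; the last computes the action."""
    return [
        ("ResultStage" if i == len(stages) - 1 else "ShuffleMapStage", stage)
        for i, stage in enumerate(stages)
    ]

job = ["read", "where", "select", "groupBy", "orderBy", "show"]
for kind, stage in label_stages(split_into_stages(job)):
    print(kind, stage)
```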


## Introduction to SparkSession and Coding 

In [63]:
# Creation of SparkSession
# Stop a session left over from an earlier run of this notebook, if there is one
try:
    spark.stop()
except NameError:
    pass

from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
from pyspark.sql import functions as f

spark = (SparkSession
         .builder
         .appName("test")
         .config("spark.executor.instances", 4)
         .config("spark.executor.cores", 4)
         .getOrCreate())
spark.sparkContext.setLogLevel("ERROR")

In [64]:
#What is SparkSession 
spark

In [65]:
# Create schema for user table
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, _parse_datatype_string

user_schema = StructType([
    StructField(name="id", dataType=IntegerType(), nullable=True),
    StructField(name="age", dataType=IntegerType(), nullable=True),
    StructField(name="gender", dataType=StringType(), nullable=True),
    StructField(name="occupation", dataType=StringType(), nullable=True),
    StructField(name="zipcode", dataType=IntegerType(), nullable=True)
])


# Create a schema from a DDL string using _parse_datatype_string
# (an internal PySpark helper that needs an active SparkContext; the DDL string itself
#  can also be passed directly, e.g. spark.read.csv(..., schema=schema_str))
schema_str = "id int, age int, gender string, occupation string, zipcode int"
user_parsed_schema = _parse_datatype_string(schema_str)
print(user_parsed_schema)

StructType([StructField('id', IntegerType(), True), StructField('age', IntegerType(), True), StructField('gender', StringType(), True), StructField('occupation', StringType(), True), StructField('zipcode', IntegerType(), True)])


#### Creation of DataFrame

In [66]:
user_df = spark.read.csv(path="data/u.user", sep="|", schema=user_schema)

#### Check sample data

In [67]:
user_df.show(2)

+---+---+------+----------+-------+
| id|age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|  1| 24|     M|technician|  85711|
|  2| 53|     F|     other|  94043|
+---+---+------+----------+-------+
only showing top 2 rows



#### View Schema

In [68]:
user_df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- zipcode: integer (nullable = true)



#### Selecting columns in a DF
There are different ways of selecting columns in Spark:
1. col() and expr()
2. df["col_name"]
3. "col_name"
4. df.col_name
5. selectExpr("col_name1", "col_name2")

In [69]:
from pyspark.sql.functions import col, expr

user_df.select(col("id"), expr("id as new_id"), user_df["age"], "gender", user_df.occupation).show(5, truncate=False)

print("""
=============================
""")

user_df.selectExpr("id","id as new_id", "age", "gender", "occupation").show(2,truncate=False)

+---+------+---+------+----------+
|id |new_id|age|gender|occupation|
+---+------+---+------+----------+
|1  |1     |24 |M     |technician|
|2  |2     |53 |F     |other     |
|3  |3     |23 |M     |writer    |
|4  |4     |24 |M     |technician|
|5  |5     |33 |F     |other     |
+---+------+---+------+----------+
only showing top 5 rows



+---+------+---+------+----------+
|id |new_id|age|gender|occupation|
+---+------+---+------+----------+
|1  |1     |24 |M     |technician|
|2  |2     |53 |F     |other     |
+---+------+---+------+----------+
only showing top 2 rows



#### Filtering the values on DF

In [70]:
user_df.where(
    (col("id") > 1)
    & (user_df["age"].between(25, 35))
    & (col("gender") == "M")
    & (user_df.occupation.isin("other", "technician"))
    & (user_df.zipcode.cast(StringType()).ilike("4%"))
)\
.show(5, truncate=False)

print("""
=============================
""")


user_df.where("""
(id > 1)
AND (age BETWEEN 25 AND 35)
AND (gender = 'M')
AND (occupation in ('other', 'technician'))
AND (cast(zipcode as string) LIKE ('4%'))
""")\
.show(5, truncate=False)              

+---+---+------+----------+-------+
|id |age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|44 |26 |M     |technician|46260  |
|178|26 |M     |other     |49512  |
|440|30 |M     |other     |48076  |
|689|25 |M     |other     |45439  |
+---+---+------+----------+-------+



+---+---+------+----------+-------+
|id |age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|44 |26 |M     |technician|46260  |
|178|26 |M     |other     |49512  |
|440|30 |M     |other     |48076  |
|689|25 |M     |other     |45439  |
+---+---+------+----------+-------+



#### CAST Column 

In [71]:
user_df.select(col('id').cast('string').alias('string_id'), \
               col('age').cast(StringType()).alias('string_age'), \
               col('zipcode').cast('integer')
              ) \
.show(3,False)

+---------+----------+-------+
|string_id|string_age|zipcode|
+---------+----------+-------+
|1        |24        |85711  |
|2        |53        |94043  |
|3        |23        |32067  |
+---------+----------+-------+
only showing top 3 rows



#### Adding a new column/columns and CASE WHEN STATEMENT

In [72]:
from pyspark.sql import functions as f
user_df.withColumn("age_group", 
                  f.when(col("age") <= 20, "20's KID") \
                   .when(col("age").between(20,30), "90's KID") \
                   .when(col("age").between(30,40) , "80's KID") \
                   .otherwise("Old Person")
                  ).show(5, False)

print("""
=============================
""")

user_df.select("*", 
                  f.when(col("age") <= 20, "20's KID") \
                   .when(col("age").between(20,30), "90's KID") \
                   .when(col("age").between(30,40) , "80's KID") \
                   .otherwise("Old Person").alias("age_group")
                  ).show(5, False)

print("""
=============================
""")

user_df.select("*", expr(
    """
    CASE WHEN age <= 20 THEN "20\'s KID"
        WHEN age BETWEEN 20 AND 30 THEN "90\'s KID"
        WHEN age BETWEEN 30 AND 40 THEN "80\'s KID"
        ELSE "Old Person" END AS age_group
    """
)).show(5, truncate=False)


print("""
=============================
""")

# Adding multiple columns 

columns = {
    "new_age" : f.round(col("age") * 0.4, 2),
    "gender_abbrivation" : f.when(col("gender") == "M", "Male").otherwise("Female") 
}

user_df.withColumns(columns).show(5,truncate=False)

+---+---+------+----------+-------+----------+
|id |age|gender|occupation|zipcode|age_group |
+---+---+------+----------+-------+----------+
|1  |24 |M     |technician|85711  |90's KID  |
|2  |53 |F     |other     |94043  |Old Person|
|3  |23 |M     |writer    |32067  |90's KID  |
|4  |24 |M     |technician|43537  |90's KID  |
|5  |33 |F     |other     |15213  |80's KID  |
+---+---+------+----------+-------+----------+
only showing top 5 rows



+---+---+------+----------+-------+----------+
|id |age|gender|occupation|zipcode|age_group |
+---+---+------+----------+-------+----------+
|1  |24 |M     |technician|85711  |90's KID  |
|2  |53 |F     |other     |94043  |Old Person|
|3  |23 |M     |writer    |32067  |90's KID  |
|4  |24 |M     |technician|43537  |90's KID  |
|5  |33 |F     |other     |15213  |80's KID  |
+---+---+------+----------+-------+----------+
only showing top 5 rows



+---+---+------+----------+-------+----------+
|id |age|gender|occupation|zipcode|age_group |
+---+-
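Two details of the CASE WHEN chain above are easy to miss: the branches are checked in order and the first match wins, and BETWEEN is inclusive on both ends. So ages 20, 30 and 40, which satisfy two conditions, land in the earlier branch, and a NULL age matches no condition and falls through to otherwise(). A plain-Python mirror of the same logic (a hypothetical age_group helper, not Spark code):

```python
def age_group(age):
    """Mirror of the when/otherwise chain: the first matching branch wins."""
    if age is None:
        # In Spark, NULL <= 20 evaluates to NULL (not true), so every WHEN is skipped.
        return "Old Person"
    if age <= 20:
        return "20's KID"
    if 20 <= age <= 30:  # between(20, 30) is inclusive
        return "90's KID"
    if 30 <= age <= 40:  # between(30, 40) is inclusive
        return "80's KID"
    return "Old Person"

for age in (20, 24, 30, 33, 40, 53, None):
    print(age, "->", age_group(age))
```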

#### Adding Literals (lit; typedLit exists only in the Scala API)

In [73]:
from pyspark.sql.functions import lit

# non sequence literal value
user_df.withColumn("gender_abbrevation", f.when(col("gender") == "F", f.lit("Female")).otherwise(lit("Male")))\
.show(5, False)

print("""
=============================
""")

# sequence literal value
user_df.withColumn("gender_facilites", 
                   f.when(col("gender") == "F", f.lit(["free bus", "protection"]))\
                   .otherwise(lit(["job notifcation", "scholerships"])))\
.show(5, False)

+---+---+------+----------+-------+------------------+
|id |age|gender|occupation|zipcode|gender_abbrevation|
+---+---+------+----------+-------+------------------+
|1  |24 |M     |technician|85711  |Male              |
|2  |53 |F     |other     |94043  |Female            |
|3  |23 |M     |writer    |32067  |Male              |
|4  |24 |M     |technician|43537  |Male              |
|5  |33 |F     |other     |15213  |Female            |
+---+---+------+----------+-------+------------------+
only showing top 5 rows



+---+---+------+----------+-------+-------------------------------+
|id |age|gender|occupation|zipcode|gender_facilites               |
+---+---+------+----------+-------+-------------------------------+
|1  |24 |M     |technician|85711  |[job notifcation, scholerships]|
|2  |53 |F     |other     |94043  |[free bus, protection]         |
|3  |23 |M     |writer    |32067  |[job notifcation, scholerships]|
|4  |24 |M     |technician|43537  |[job notifcation, scholerships]|
|5

#### Renaming the Column

In [74]:
# Single Column Renamed
user_df.withColumnRenamed("id", "user_id").show(5, False)

print("""
=============================
""")

# Multiple Column Renamed
user_df.withColumnsRenamed({"id": "user_id", "age" : "new_age"}).show(5, False)


+-------+---+------+----------+-------+
|user_id|age|gender|occupation|zipcode|
+-------+---+------+----------+-------+
|1      |24 |M     |technician|85711  |
|2      |53 |F     |other     |94043  |
|3      |23 |M     |writer    |32067  |
|4      |24 |M     |technician|43537  |
|5      |33 |F     |other     |15213  |
+-------+---+------+----------+-------+
only showing top 5 rows



+-------+-------+------+----------+-------+
|user_id|new_age|gender|occupation|zipcode|
+-------+-------+------+----------+-------+
|1      |24     |M     |technician|85711  |
|2      |53     |F     |other     |94043  |
|3      |23     |M     |writer    |32067  |
|4      |24     |M     |technician|43537  |
|5      |33     |F     |other     |15213  |
+-------+-------+------+----------+-------+
only showing top 5 rows



#### Ordering the data i.e., ASC, DESC, NULLS FIRST and NULLS LAST

In [75]:
user_df.orderBy(col("age").asc(),
               col("zipcode").desc(),
               col("id").asc_nulls_first(),
               col("occupation").desc_nulls_last()
               ).show(5, False)

print("""
=============================
""")

user_df.sort(col("age").asc(),
               col("zipcode").desc(),
               col("id").asc_nulls_first(),
               col("occupation").desc_nulls_last()
               ).show(5, False)

+---+----+------+----------+-------+
|id |age |gender|occupation|zipcode|
+---+----+------+----------+-------+
|116|NULL|M     |healthcare|97232  |
|17 |NULL|M     |programmer|6355   |
|30 |7   |M     |student   |55436  |
|471|10  |M     |student   |77459  |
|289|11  |M     |none      |94619  |
+---+----+------+----------+-------+
only showing top 5 rows



+---+----+------+----------+-------+
|id |age |gender|occupation|zipcode|
+---+----+------+----------+-------+
|116|NULL|M     |healthcare|97232  |
|17 |NULL|M     |programmer|6355   |
|30 |7   |M     |student   |55436  |
|471|10  |M     |student   |77459  |
|289|11  |M     |none      |94619  |
+---+----+------+----------+-------+
only showing top 5 rows
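In the output above the NULL ages come first because Spark's default null ordering is NULLS FIRST for ascending sorts and NULLS LAST for descending ones; asc_nulls_last(), desc_nulls_first() and the related methods override that. A plain-Python sketch of those rules (a hypothetical sort_with_nulls helper):

```python
def sort_with_nulls(values, ascending=True, nulls_first=None):
    """Sort values the way Spark orders a single column, None standing for NULL."""
    if nulls_first is None:
        nulls_first = ascending  # Spark default: ASC NULLS FIRST, DESC NULLS LAST
    nulls = [v for v in values if v is None]
    non_nulls = sorted((v for v in values if v is not None), reverse=not ascending)
    return nulls + non_nulls if nulls_first else non_nulls + nulls

ages = [24, None, 7, 53, None, 10]
print(sort_with_nulls(ages))                     # asc():            [None, None, 7, 10, 24, 53]
print(sort_with_nulls(ages, ascending=False))    # desc():           [53, 24, 10, 7, None, None]
print(sort_with_nulls(ages, nulls_first=False))  # asc_nulls_last(): [7, 10, 24, 53, None, None]
```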



#### String functions

In [76]:
user_df.withColumn('len_of_occupation', f.length(user_df['occupation'])).show(1, False)
print("""=============================""")
user_df.withColumn('len_of_occupation', f.substring(user_df['occupation'], 1, 5)).show(1, False)
print("""=============================""")
user_df.withColumn('len_of_occupation', f.regexp_replace(user_df['occupation'], 'a', 'z')).show(1, False)

+---+---+------+----------+-------+-----------------+
|id |age|gender|occupation|zipcode|len_of_occupation|
+---+---+------+----------+-------+-----------------+
|1  |24 |M     |technician|85711  |10               |
+---+---+------+----------+-------+-----------------+
only showing top 1 row

+---+---+------+----------+-------+-----------------+
|id |age|gender|occupation|zipcode|len_of_occupation|
+---+---+------+----------+-------+-----------------+
|1  |24 |M     |technician|85711  |techn            |
+---+---+------+----------+-------+-----------------+
only showing top 1 row

+---+---+------+----------+-------+-----------------+
|id |age|gender|occupation|zipcode|len_of_occupation|
+---+---+------+----------+-------+-----------------+
|1  |24 |M     |technician|85711  |technicizn       |
+---+---+------+----------+-------+-----------------+
only showing top 1 row



#### Date Functions

In [77]:
user_df.withColumn("dob", f.to_date(lit("2025-13-01"), "yyyy-dd-MM")).show(3, False)
print("""=============================""")
user_df.withColumn("current_date", f.current_date()).show(3, False)
print("""=============================""")

# convert the date into String
user_df.withColumn("todays_date", f.date_format(f.current_date(), 'dd/MM/yyyy')).show(3, False)

+---+---+------+----------+-------+----------+
|id |age|gender|occupation|zipcode|dob       |
+---+---+------+----------+-------+----------+
|1  |24 |M     |technician|85711  |2025-01-13|
|2  |53 |F     |other     |94043  |2025-01-13|
|3  |23 |M     |writer    |32067  |2025-01-13|
+---+---+------+----------+-------+----------+
only showing top 3 rows

+---+---+------+----------+-------+------------+
|id |age|gender|occupation|zipcode|current_date|
+---+---+------+----------+-------+------------+
|1  |24 |M     |technician|85711  |2025-09-13  |
|2  |53 |F     |other     |94043  |2025-09-13  |
|3  |23 |M     |writer    |32067  |2025-09-13  |
+---+---+------+----------+-------+------------+
only showing top 3 rows

+---+---+------+----------+-------+-----------+
|id |age|gender|occupation|zipcode|todays_date|
+---+---+------+----------+-------+-----------+
|1  |24 |M     |technician|85711  |13/09/2025 |
|2  |53 |F     |other     |94043  |13/09/2025 |
|3  |23 |M     |writer    |32067  |13/

#### Drop Null values

In [78]:
# Drop rows that contain a null in any column
user_df.na.drop()  # same as user_df.dropna()

# Drop rows that contain a null in specific columns
user_df.na.drop(subset=["age", "occupation"])  # same as user_df.dropna(subset=["age", "occupation"])

# Drop a column
user_df.withColumn("current_date", f.current_date()).drop("current_date").show(2, False)

# Drop duplicate rows
user_df.drop_duplicates()

+---+---+------+----------+-------+
|id |age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|1  |24 |M     |technician|85711  |
|2  |53 |F     |other     |94043  |
+---+---+------+----------+-------+
only showing top 2 rows



DataFrame[id: int, age: int, gender: string, occupation: string, zipcode: int]

#### Union Vs Union ALL

In [79]:
emp_df_1 = user_df.limit(5)

emp_data = [[6,22,'M', 'Software', 89893], [1,24,'M', 'technician', 85711]]
emp_df_2 = spark.createDataFrame(data=emp_data, schema=user_df.columns)

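# union() keeps duplicates, like SQL UNION ALL (columns are matched by position)
emp_df_1.union(emp_df_2).show(7, False)

print("""=============================""")

# union() followed by distinct() behaves like SQL UNION
emp_df_1.union(emp_df_2).distinct().show(7, False)

print("""=============================""")

# unionByName() matches columns by name instead of by position
emp_df_1.unionByName(emp_df_2).show()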

+---+---+------+----------+-------+
|id |age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|1  |24 |M     |technician|85711  |
|2  |53 |F     |other     |94043  |
|3  |23 |M     |writer    |32067  |
|4  |24 |M     |technician|43537  |
|5  |33 |F     |other     |15213  |
|6  |22 |M     |Software  |89893  |
|1  |24 |M     |technician|85711  |
+---+---+------+----------+-------+

+---+---+------+----------+-------+
|id |age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|3  |23 |M     |writer    |32067  |
|1  |24 |M     |technician|85711  |
|4  |24 |M     |technician|43537  |
|5  |33 |F     |other     |15213  |
|2  |53 |F     |other     |94043  |
|6  |22 |M     |Software  |89893  |
+---+---+------+----------+-------+

+---+---+------+----------+-------+
| id|age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|  1| 24|     M|technician|  85711|
|  2| 53|     F|     other|  94043|
|  3| 23|     M|    writer|  32067|
|  4| 24|     M|technician
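union() matches columns by position, while unionByName() matches them by name. A plain-Python model with rows as tuples shows why that matters when the second DataFrame lists its columns in a different order (hypothetical helpers; Spark would additionally try to coerce, or reject, mismatched column types):

```python
def union_all(rows_a, rows_b):
    """union(): append rows positionally and keep duplicates (SQL UNION ALL)."""
    return rows_a + rows_b

def union_distinct(rows_a, rows_b):
    """union().distinct(): drop duplicate rows (SQL UNION).
    This model keeps first-seen order; Spark does not guarantee any order."""
    seen, out = set(), []
    for row in rows_a + rows_b:
        if row not in seen:
            seen.add(row)
            out.append(row)
    return out

def union_by_name(rows_a, cols_a, rows_b, cols_b):
    """unionByName(): reorder each row of b so its values line up under cols_a.
    A column missing from cols_b raises ValueError here; Spark raises an error
    too unless allowMissingColumns=True is passed."""
    positions = [cols_b.index(name) for name in cols_a]
    return rows_a + [tuple(row[i] for i in positions) for row in rows_b]

cols_a = ["id", "age", "gender"]
rows_a = [(1, 24, "M")]
cols_b = ["gender", "id", "age"]  # same columns, different order
rows_b = [("F", 2, 53), ("M", 1, 24)]

print(union_all(rows_a, rows_b))                      # values misaligned under cols_a
print(union_by_name(rows_a, cols_a, rows_b, cols_b))  # [(1, 24, 'M'), (2, 53, 'F'), (1, 24, 'M')]
print(union_distinct(rows_a, union_by_name([], cols_a, rows_b, cols_b)))  # [(1, 24, 'M'), (2, 53, 'F')]
```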

#### Aggregations

In [80]:
# Avg age per Gender
user_df.groupby(col("gender"))\
.agg(f.round(f.avg(user_df.age),2).alias("avg_age_per_gender"))\
.show()

# count of occupations 
user_df.groupby(user_df.occupation)\
.agg(f.count(user_df.id).alias("count_of_occupation"))\
.orderBy(col("count_of_occupation").desc())\
.show(5, False)

# distinct column values
user_df.select(user_df.gender).distinct().show()

+------+------------------+
|gender|avg_age_per_gender|
+------+------------------+
|     F|             33.83|
|     M|             34.15|
|  NULL|              28.0|
+------+------------------+

+-------------+-------------------+
|occupation   |count_of_occupation|
+-------------+-------------------+
|student      |196                |
|other        |105                |
|educator     |95                 |
|administrator|78                 |
|engineer     |66                 |
+-------------+-------------------+
only showing top 5 rows

+------+
|gender|
+------+
|     F|
|     M|
|  NULL|
+------+
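The NULL row in the first result is not an error: groupBy() treats NULL keys as a group of their own, and avg() ignores NULL values (a group whose values are all NULL averages to NULL). A plain-Python sketch of those rules, using made-up rows and a hypothetical avg_by_key helper:

```python
from collections import defaultdict

def avg_by_key(rows, key, value):
    """groupBy(key).agg(avg(value)) semantics in plain Python; None stands for NULL."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[key]].append(row[value])  # None is a valid key -> a NULL group
    result = {}
    for k, values in groups.items():
        non_null = [v for v in values if v is not None]  # avg() skips NULLs
        result[k] = sum(non_null) / len(non_null) if non_null else None
    return result

rows = [
    {"gender": "M", "age": 24}, {"gender": "M", "age": None},
    {"gender": "F", "age": 53}, {"gender": "F", "age": 33},
    {"gender": None, "age": 28}, {"gender": "X", "age": None},
]
print(avg_by_key(rows, "gender", "age"))  # {'M': 24.0, 'F': 43.0, None: 28.0, 'X': None}
```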



#### Window Functions

In [81]:
from pyspark.sql.window import Window

window = Window.partitionBy(col("gender")).orderBy(col("age"))

# Dense Rank
user_df.withColumn("rnk", f.dense_rank().over(window=window)).show(5,False)

# Running average of age per gender, from oldest to youngest
# (despite the column name max_age_until_now, f.avg computes an average, not a maximum)
user_df.withColumn("max_age_until_now", 
                   f.round(f.avg(col("age"))
                   .over(window=Window.partitionBy(col("gender"))
                         .orderBy(f.desc("age"))
                         .rowsBetween(Window.unboundedPreceding, Window.currentRow))
                  ,2)).show(5, False)

+---+---+------+----------+-------+---+
|id |age|gender|occupation|zipcode|rnk|
+---+---+------+----------+-------+---+
|609|13 |F     |student   |55106  |1  |
|674|13 |F     |student   |55337  |1  |
|206|14 |F     |student   |53115  |2  |
|813|14 |F     |student   |2136   |2  |
|887|14 |F     |student   |27249  |2  |
+---+---+------+----------+-------+---+
only showing top 5 rows

+---+---+------+-------------+-------+-----------------+
|id |age|gender|occupation   |zipcode|max_age_until_now|
+---+---+------+-------------+-------+-----------------+
|860|70 |F     |retired      |48322  |70.0             |
|266|62 |F     |administrator|78756  |66.0             |
|131|59 |F     |administrator|15237  |63.67            |
|754|59 |F     |librarian    |62901  |62.5             |
|591|57 |F     |librarian    |92093  |61.4             |
+---+---+------+-------------+-------+-----------------+
only showing top 5 rows
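What the two window expressions compute can be reproduced in plain Python for a single, already-sorted partition. The helpers are hypothetical; the inputs are the first ages of the F partition shown in the outputs above, and the results match the rnk and max_age_until_now columns, confirming that the second column is a running average rather than a running maximum.

```python
def dense_rank(sorted_values):
    """dense_rank() over an already-ordered partition: ties share a rank, no gaps."""
    ranks, rank, prev = [], 0, object()  # object() never equals a real value
    for v in sorted_values:
        if v != prev:
            rank += 1
            prev = v
        ranks.append(rank)
    return ranks

def running_avg(sorted_values, ndigits=2):
    """avg() over rowsBetween(unboundedPreceding, currentRow): mean of all rows so far."""
    out, total = [], 0
    for i, v in enumerate(sorted_values, start=1):
        total += v
        out.append(round(total / i, ndigits))
    return out

print(dense_rank([13, 13, 14, 14, 14]))   # [1, 1, 2, 2, 2]
print(running_avg([70, 62, 59, 59, 57]))  # [70.0, 66.0, 63.67, 62.5, 61.4]
```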



#### Partitions 
1. **repartition:** repartition() can increase or decrease the number of partitions, but it performs a full shuffle of the data across the cluster.
2. **coalesce:** coalesce() can only decrease the number of partitions. It merges existing partitions instead of performing a full shuffle, so it is usually cheaper than repartition().
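A plain-Python model of the difference, under simplified assumptions: coalesce here merges neighbouring partitions (Spark also takes data locality into account), and repartition deals records out round-robin (Spark starts each input partition at a random offset). Both helpers are hypothetical.

```python
def coalesce(partitions, n):
    """Merge existing partitions into n groups; records stay with their
    partition-mates, and asking for more partitions than exist changes nothing."""
    k = len(partitions)
    n = min(n, k)  # coalesce cannot increase the partition count
    bounds = [i * k // n for i in range(n + 1)]
    return [
        [rec for part in partitions[bounds[i]:bounds[i + 1]] for rec in part]
        for i in range(n)
    ]

def repartition(partitions, n):
    """Full shuffle: every record is dealt out round-robin into n new partitions."""
    out = [[] for _ in range(n)]
    i = 0
    for part in partitions:
        for rec in part:
            out[i % n].append(rec)
            i += 1
    return out

parts = [[1, 2, 3], [4], [5, 6], [7], [8, 9, 10]]
print(coalesce(parts, 2))     # [[1, 2, 3, 4], [5, 6, 7, 8, 9, 10]] -> uneven sizes possible
print(repartition(parts, 2))  # [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]] -> balanced, but every record moved
print(len(coalesce(parts, 10)))  # 5 -> no increase
```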

In [82]:
# get No of Partitions
print(user_df.rdd.getNumPartitions())

# Increase the partitions 
print(user_df.repartition(5).rdd.getNumPartitions())

# Decrease the partitions (coalesce can only reduce the count, so start from 5)
print(user_df.repartition(5).coalesce(4).rdd.getNumPartitions())

# Identify the spark_partition_id value
user_df.repartition(100, col("age")).withColumn("partition_id", f.spark_partition_id()).show(30, False)

1
5
4
+---+---+------+-------------+-------+------------+
|id |age|gender|occupation   |zipcode|partition_id|
+---+---+------+-------------+-------+------------+
|63 |31 |M     |marketing    |75240  |2           |
|91 |55 |M     |marketing    |1913   |2           |
|95 |31 |M     |administrator|10707  |2           |
|115|31 |M     |engineer     |17110  |2           |
|131|59 |F     |administrator|15237  |2           |
|134|31 |M     |programmer   |80236  |2           |
|145|31 |M     |entertainment|NULL   |2           |
|172|55 |M     |marketing    |22207  |2           |
|197|55 |M     |technician   |75094  |2           |
|224|31 |F     |educator     |43512  |2           |
|269|31 |F     |librarian    |43201  |2           |
|295|31 |M     |educator     |50325  |2           |
|315|31 |M     |educator     |18301  |2           |
|388|31 |M     |other        |36106  |2           |
|413|55 |M     |educator     |78212  |2           |
|418|55 |F     |none         |21206  |2           |
|426|5

#### Joins

In [83]:
dept_data = [[55105, 'Dusseldorf'], [27713, 'Essen'], [32114, 'Dortmund'], [18301, 'Cologne'], [21010, 'Berlin']]
dept_schema = ['zip_code', 'city']

dept_df = spark.createDataFrame(data=dept_data, schema=dept_schema)

# Inner: only users whose zipcode matches a city
user_df.join(dept_df, user_df.zipcode == dept_df.zip_code, how='inner').show(5)

# Left: all users; city columns are NULL where no zipcode matches
user_df.join(dept_df, user_df.zipcode == dept_df.zip_code, how='left').show(5)

# Right: all cities; user columns are NULL where no user matches
user_df.join(dept_df, user_df.zipcode == dept_df.zip_code, how='right').show(100)

# Anti: users without a matching city (only user columns are returned)
user_df.join(dept_df, user_df.zipcode == dept_df.zip_code, how='anti').show(5)

# Semi: users with at least one matching city (only user columns are returned)
user_df.join(dept_df, user_df.zipcode == dept_df.zip_code, how='semi').show(100)

+---+---+------+----------+-------+--------+----------+
| id|age|gender|occupation|zipcode|zip_code|      city|
+---+---+------+----------+-------+--------+----------+
|653| 31|     M| executive|  55105|   55105|Dusseldorf|
|421| 38|     F|programmer|  55105|   55105|Dusseldorf|
|352| 37|     F|programmer|  55105|   55105|Dusseldorf|
|196| 49|     M|    writer|  55105|   55105|Dusseldorf|
| 52| 18|     F|   student|  55105|   55105|Dusseldorf|
+---+---+------+----------+-------+--------+----------+
only showing top 5 rows

+---+---+------+----------+-------+--------+----+
| id|age|gender|occupation|zipcode|zip_code|city|
+---+---+------+----------+-------+--------+----+
|  3| 23|     M|    writer|  32067|    NULL|NULL|
|  5| 33|     F|     other|  15213|    NULL|NULL|
|  4| 24|     M|technician|  43537|    NULL|NULL|
|  1| 24|     M|technician|  85711|    NULL|NULL|
|  6| 42|     M| executive|  98101|    NULL|NULL|
+---+---+------+----------+-------+--------+----+
only showing top 5 ro
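Semi and anti joins are the least familiar of these join types. In plain Python (hypothetical helpers and made-up rows), a left semi join keeps the left rows that have at least one match, and a left anti join keeps those that have none; neither returns right-side columns, and a semi join never duplicates a left row even when it matches several right rows.

```python
def matching_keys(right, rkey):
    # NULL keys never satisfy an equality join condition
    return {r[rkey] for r in right if r[rkey] is not None}

def left_semi(left, right, lkey, rkey):
    keys = matching_keys(right, rkey)
    return [row for row in left if row[lkey] in keys]

def left_anti(left, right, lkey, rkey):
    keys = matching_keys(right, rkey)
    return [row for row in left if row[lkey] not in keys]  # NULL-key rows count as unmatched

users = [
    {"id": 1, "zipcode": 55105}, {"id": 2, "zipcode": 94043},
    {"id": 3, "zipcode": 55105}, {"id": 4, "zipcode": None},
]
cities = [
    {"zip_code": 55105, "city": "Dusseldorf"}, {"zip_code": 55105, "city": "Dusseldorf-Nord"},
    {"zip_code": 27713, "city": "Essen"},
]
print([u["id"] for u in left_semi(users, cities, "zipcode", "zip_code")])  # [1, 3]
print([u["id"] for u in left_anti(users, cities, "zipcode", "zip_code")])  # [2, 4]
```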

#### Reading the CSV data

We focus on PERMISSIVE mode (the default) while reading the CSV file; this mode lets us deal with corrupt records during parsing. When PERMISSIVE meets a corrupted record, it sets the malformed fields to null and stores the raw malformed line in a column called _corrupt_record. When a schema is supplied, this column must be included in it as a string column, and its name can be changed with the columnNameOfCorruptRecord option.
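To make the PERMISSIVE behaviour concrete, here is a plain-Python model for this pipe-separated file (a hypothetical parse_permissive helper; it assumes every line has exactly five fields and only models type-conversion failures, not Spark's full CSV rules):

```python
SCHEMA = [("id", int), ("age", int), ("gender", str), ("occupation", str), ("zipcode", int)]

def parse_permissive(line, schema=SCHEMA, sep="|", corrupt_col="_corrupt_record"):
    """Convert each field; a field that fails to convert becomes None (NULL) and
    the whole raw line is kept in corrupt_col. Assumes len(fields) == len(schema)."""
    row, malformed = {}, False
    for (name, cast), token in zip(schema, line.split(sep)):
        if token == "":
            row[name] = None  # empty strings are read as NULL by default
            continue
        try:
            row[name] = cast(token)
        except ValueError:
            row[name] = None  # only the malformed field is nulled
            malformed = True
    row[corrupt_col] = line if malformed else None
    return row

print(parse_permissive("1|24|M|technician|85711"))
print(parse_permissive("28|32|M|writer|NULL"))  # zipcode -> None, raw line kept
```

The second call mirrors row 28 in the output below: the well-formed fields survive, zipcode becomes NULL, and the raw line lands in _corrupt_record.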

In [84]:
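# StructType.add() appends the field in place (and returns the same object),
# so re-running this cell would add _corrupt_record a second time
user_schema = user_schema.add(StructField(name="_corrupt_record", dataType=StringType(), nullable=True))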

user_csv_df = spark.read.csv(path="data/u.user", 
                             header="false", 
                             inferSchema="false", 
                             mode="PERMISSIVE", 
                             sep="|", 
                             schema=user_schema)

In [85]:
user_csv_df.where(user_csv_df._corrupt_record.isNotNull()).show()

+---+----+------+-------------+-------+--------------------+
| id| age|gender|   occupation|zipcode|     _corrupt_record|
+---+----+------+-------------+-------+--------------------+
| 17|NULL|     M|   programmer|   6355|17|T|M|programmer...|
| 28|  32|     M|       writer|   NULL| 28|32|M|writer|NULL|
| 74|  39|     M|    scientist|   NULL|74|39|M|scientist...|
|116|NULL|     M|   healthcare|  97232|116|NULL|M|health...|
|145|  31|     M|entertainment|   NULL|145|31|M|entertai...|
|167|  37|     M|        other|   NULL|167|37|M|other|L9G2B|
|201|  27|     M|       writer|   NULL|201|27|M|writer|E...|
|213|  33|     M|    executive|   NULL|213|33|M|executiv...|
|333|  47|     M|        other|   NULL|333|47|M|other|V0R2M|
|458|  47|     M|   technician|   NULL|458|47|M|technici...|
|490|  29|     F|       artist|   NULL|490|29|F|artist|V...|
|578|  31|     M|administrator|   NULL|578|31|M|administ...|
|594|  46|     M|     educator|   NULL|594|46|M|educator...|
|599|  22|     F|      s

#### Reading the Parquet Data

Apache Parquet is a columnar file format that supports compression and stores the schema inside the file. Parquet is Spark's default data source format (`spark.sql.sources.default`), and it is widely used in OLAP workloads and modern data platforms such as BigQuery, AWS Athena, Hive, and Snowflake.

<img src="https://miro.medium.com/v2/resize:fit:1126/0*vWHK7tn_sxWl_U6k.png" width="300" height="300">

#### How does Parquet store the data?
Inside a Parquet file, data is organized into 4 parts:
1. Row Group:
    > Horizontal partition of the data \
    > Each row group contains the data for all the columns but only for a subset of rows  \
    > It enables parallel reads  \
2. Column Chunk:
    > It holds the values of a single column in a Row Group  \
    > It enables the column level access  \
3. Pages:
    > Column chunks are divided into pages  \
    > Data pages contain actual data  \
    > Pages are individually compressed, improving memory usage  \
4. Footer:
    > Stored at the end of the Parquet file  \
    > Contains: Column names and data types, Row group info (offset, size), Min/max stats for columns and Compression and encoding info  \
    > Enables: Predicate pushdown, Schema discovery, and Column/row group skipping  \

In [86]:
# Create the parquet file
# user_csv_df.repartition(3).write.partitionBy("occupation").mode("overwrite").save("data/user.parquet")

In [87]:
# read the parquet file
user_parquet_df = spark.read.parquet("data/user.parquet")

In [88]:
select_cols = ["id", "age", "gender", "zipcode", "occupation"]
user_parquet_df.where(col("occupation").rlike("[^0-9]")).select(*select_cols).show(10)

+---+---+------+-------+----------+
| id|age|gender|zipcode|occupation|
+---+---+------+-------+----------+
|262| 19|     F|  78264|   student|
|140| 30|     F|  32250|   student|
|584| 25|     M|  27511|   student|
|257| 17|     M|  77005|   student|
|372| 25|     F|  66046|   student|
|586| 20|     M|  79508|   student|
|259| 21|     M|  48823|   student|
|632| 18|     M|  55454|   student|
| 73| 24|     M|  41850|   student|
|361| 22|     M|  44074|   student|
+---+---+------+-------+----------+
only showing top 10 rows



#### Reading the JSON data 

JSON (JavaScript Object Notation) is a human-readable text format, commonly used for data exchange.
It is built on two structures: objects and arrays.
1. Objects: Unordered sets of key-value pairs, enclosed in curly braces {}. Each key is a string, and the value can be a primitive data type (string, number, boolean, null), another object, or an array.
2. Arrays: Ordered lists of values, enclosed in square brackets []. Values can be any valid JSON data type.

JSON does not support comments.
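The two structures and the no-comments rule can be checked with Python's standard `json` module. A small sketch; the sample document is hypothetical and only mirrors the shape of the file read below.

```python
import json

doc = '{"name": "John Doe", "courses": ["Math", "Physics"], "address": {"zip": "10001"}}'
parsed = json.loads(doc)

# A JSON object becomes a dict, a JSON array becomes a list
print(type(parsed).__name__, type(parsed["courses"]).__name__)

# Comments are not part of the JSON grammar, so a strict parser rejects them
try:
    json.loads('{"age": 30 /* comment */}')
    comments_allowed = True
except json.JSONDecodeError:
    comments_allowed = False
print(comments_allowed)
```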

In [89]:
json_df = spark.read.option("multiline", True).json("data/json_example.json")


In [90]:
json_df.show(truncate=False)

+--------------------+---+--------+------------------------+---------+--------+
|address             |age|city    |courses                 |isStudent|name    |
+--------------------+---+--------+------------------------+---------+--------+
|{123 Main St, 10001}|30 |New York|[Math, Physics, History]|false    |John Doe|
+--------------------+---+--------+------------------------+---------+--------+



In [91]:
json_df.printSchema()

root
 |-- address: struct (nullable = true)
 |    |-- street: string (nullable = true)
 |    |-- zip: string (nullable = true)
 |-- age: long (nullable = true)
 |-- city: string (nullable = true)
 |-- courses: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- isStudent: boolean (nullable = true)
 |-- name: string (nullable = true)



In [92]:
json_df.select(col("address.street").alias("street"),
               col("address.zip").alias("zip"),
               col("age"),col("city"),col("isStudent"), col("name"),
               f.explode_outer("courses")).show(truncate=False)

+-----------+-----+---+--------+---------+--------+-------+
|street     |zip  |age|city    |isStudent|name    |col    |
+-----------+-----+---+--------+---------+--------+-------+
|123 Main St|10001|30 |New York|false    |John Doe|Math   |
|123 Main St|10001|30 |New York|false    |John Doe|Physics|
|123 Main St|10001|30 |New York|false    |John Doe|History|
+-----------+-----+---+--------+---------+--------+-------+



In [93]:
# DDL-formatted schema string; field names must match the JSON keys exactly (isStudent, not isstudent)
new_schema = "address STRUCT<street: STRING, zip: STRING>, age INT, city STRING, courses ARRAY<STRING>, isStudent BOOLEAN, name STRING"

In [94]:
# The text reader returns one row per line; "multiline" is not a text-source option, so it is dropped here
json_df_1 = spark.read.text("data/json_example.json")

In [95]:
json_df_1.show()

+--------------------+
|               value|
+--------------------+
|                   {|
|  "name": "John D...|
|          "age": 30,|
|  "city": "New Yo...|
|  "isStudent": fa...|
|        "address": {|
|    "street": "12...|
|      "zip": "10001"|
|                  },|
|        "courses": [|
|             "Math",|
|          "Physics",|
|           "History"|
|                   ]|
|                   }|
+--------------------+



In [96]:
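# Rebuild the JSON document from its lines. collect_list does not guarantee ordering in general,
# so this relies on the lines staying in file order (which holds in practice for a small single-partition file)
json_string_df = json_df_1.select(f.collect_list("value").alias("lines")) \
    .withColumn("json_str", f.concat_ws("", "lines"))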

In [97]:
json_string_df.show(truncate=False)

+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|lines                                                                                                                                                                                                                             |json_str                                                                                                                                                                                            |
+---------------------------------------------------------------------------------------------------------------------------------------------------

#### Write Modes in Spark

**`DataFrameWriter.mode()` specifies the behavior when the target data or table already exists.**
Options include:
1. append: Append the contents of this DataFrame to the existing data.
2. overwrite: Overwrite the existing data.
3. error or errorifexists (default): Throw an exception if data already exists.
4. ignore: Silently skip the write if data already exists.

In [98]:
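# Default parallelism in Spark (the values depend on the machine or cluster)

print(spark.sparkContext.defaultParallelism) # 12 here; in local[*] mode this is the number of CPU cores
print(spark.sparkContext.defaultMinPartitions) # 2, i.e. min(defaultParallelism, 2)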

12
2


In [99]:
new_user_df = user_df.select("*")
new_user_df.show(2)

+---+---+------+----------+-------+
| id|age|gender|occupation|zipcode|
+---+---+------+----------+-------+
|  1| 24|     M|technician|  85711|
|  2| 53|     F|     other|  94043|
+---+---+------+----------+-------+
only showing top 2 rows



In [100]:
# Repartition into 5 partitions (this triggers a full shuffle)
new_user_df = new_user_df.repartition(5)
new_user_df.rdd.getNumPartitions()

5

In [101]:
# Write as Parquet (the default format of .save()), partitioned on disk by occupation
# into one sub-directory per occupation value
new_user_df.write.partitionBy("occupation").mode("overwrite").save("data/user_parquet.parquet")

                                                                                

#### Deployment Modes in Spark
##### In Apache Spark, there are two primary deployment modes:
1. **Client Mode:** In this mode, the driver program runs on the machine that submits the application (typically an edge node or the master node, outside the worker nodes). The cluster manager launches the executors on the cluster nodes, but the driver stays on the client machine. \
    a. If the client session (e.g., your terminal or notebook) is terminated, the Spark application also terminates, because the driver is no longer running. \
    b. This mode is typically used for interactive sessions, debugging, or development. 
2. **Cluster Mode:** In cluster mode, the cluster manager launches the driver program inside the cluster, alongside the executors. \
    a. This is the recommended approach for production workloads, because the application keeps running independently of the client process that submitted it. \
    b. Since the driver runs inside the cluster, close to the executors, driver-executor communication has lower network latency than with a remote client machine. \
    c. Even if the client disconnects or is terminated, the Spark application continues to run.

### UDF

UDF (User Defined Function): a custom function written in Python that you can apply to Spark DataFrame columns when the built-in Spark functions (pyspark.sql.functions) are not sufficient.

✅ Why Use UDFs?  \
	•	To apply custom Python logic to each row or column.  \
	•	When there’s no equivalent Spark function available (e.g., complex text processing, custom regex, external library logic).

⚠️ UDF Performance Caveats \
	•	Slower than native functions: Catalyst treats a UDF as a black box and cannot optimize the logic inside it. \
	•	Serialization overhead: rows are serialized and shipped from the executor JVM to separate Python worker processes, and the results are sent back.  \
	•	Harder to optimize: filters that use a UDF cannot be pushed down to the data source.

🏎️ Better Alternatives to UDFs (when possible): \
Use native Spark functions from pyspark.sql.functions like: \
	•	when, col, regexp_extract, split, concat_ws \
	•	expr(), withColumn(), etc. 

✅ These are faster and Spark-optimized.

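🚀 Pandas UDF (Vectorized UDF) \
If you must use custom Python logic, prefer Pandas UDFs (aka vectorized UDFs). They transfer data in batches using Apache Arrow and operate on whole pandas Series at once, which is usually much faster than row-at-a-time Python UDFs.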

In [102]:
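from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Each variant capitalizes the first letter; empty strings and nulls are returned unchanged

# Way-1: decorator
@udf(returnType=StringType())
def camelCase(word: str):
    if not word:
        return word
    return word[0].upper() + word[1:]

# Way-2: wrap a plain Python function with udf()
def camelCase_v1(word: str):
    if not word:
        return word
    return word[0].upper() + word[1:]

camelCase_v1 = udf(camelCase_v1, StringType())

# Way-3: wrap a lambda that calls a plain Python function
def camelCase_v2_fn(word: str):
    if not word:
        return word
    return word[0].upper() + word[1:]

camelCase_v2 = udf(lambda a: camelCase_v2_fn(a), StringType())

# Way-4: register the function so it can be called from Spark SQL
spark.udf.register("camelCase_v3", camelCase_v2_fn, StringType())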

<function __main__.camelCase_v2_fn(word: str)>

In [103]:
user_df.select(user_df.occupation, 
              camelCase(user_df.occupation).alias("new_name"),
              camelCase_v1(user_df.occupation).alias("new_name_v1"),
              camelCase_v2(user_df.occupation).alias("new_name_v2")).show(5)

user_df.createOrReplaceTempView("user_df")

spark.sql("select occupation, camelCase_v3(occupation) as new_name_v3 from user_df").show(5)

+----------+----------+-----------+-----------+
|occupation|  new_name|new_name_v1|new_name_v2|
+----------+----------+-----------+-----------+
|technician|Technician| Technician| Technician|
|     other|     Other|      Other|      Other|
|    writer|    Writer|     Writer|     Writer|
|technician|Technician| Technician| Technician|
|     other|     Other|      Other|      Other|
+----------+----------+-----------+-----------+
only showing top 5 rows

+----------+-----------+
|occupation|new_name_v3|
+----------+-----------+
|technician| Technician|
|     other|      Other|
|    writer|     Writer|
|technician| Technician|
|     other|      Other|
+----------+-----------+
only showing top 5 rows



#### explain() vs toDebugString() / Lineage

RDD lineage is the sequence of transformations that defines how an RDD is derived from other RDDs, forming a logical execution graph that Spark uses for fault recovery and optimization. For DataFrames, the transformations are represented as a logical plan and optimized by Spark’s Catalyst optimizer.

In [104]:
user_df.explain()

== Physical Plan ==
FileScan csv [id#2819,age#2820,gender#2821,occupation#2822,zipcode#2823] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/saimammahi/Documents/Work/Interview/Pyspark-Playground/dat..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int,age:int,gender:string,occupation:string,zipcode:int>




In [105]:
user_df.rdd.toDebugString()

b'(1) MapPartitionsRDD[167] at javaToPython at NativeMethodAccessorImpl.java:0 []\n |  MapPartitionsRDD[166] at javaToPython at NativeMethodAccessorImpl.java:0 []\n |  SQLExecutionRDD[165] at javaToPython at NativeMethodAccessorImpl.java:0 []\n |  MapPartitionsRDD[164] at javaToPython at NativeMethodAccessorImpl.java:0 []\n |  FileScanRDD[163] at javaToPython at NativeMethodAccessorImpl.java:0 []'

#### Optimize the Shuffles in Spark

In [106]:
customer_orders_df = spark.read.format("parquet").load("data/user.parquet")

In [107]:
customer_orders_df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- zipcode: integer (nullable = true)
 |-- _corrupt_record: string (nullable = true)
 |-- occupation: string (nullable = true)



In [108]:
spark.conf.get("spark.sql.shuffle.partitions")

'200'

In [109]:
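# Number of partitions produced by shuffles (joins, aggregations); 200 is the default, lower it for small data
spark.conf.set("spark.sql.shuffle.partitions", 200)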

In [110]:
spark.conf.get("spark.sql.shuffle.partitions")

'200'

In [111]:
customer_orders_df_1 = customer_orders_df.select(f.count_distinct(f.col("occupation")))

### Cache VS Persist

- **Caching** and **persisting** are techniques to improve performance for **iterative or interactive Spark workloads** by storing intermediate results.

- These methods avoid recomputing expensive transformations across multiple actions.

#### 1. Default Behavior of `cache()`

- For **RDDs**:
  - `cache()` uses the default storage level: **`MEMORY_ONLY`**.
  
- For **DataFrames/Datasets**:
  - `cache()` defaults to **`MEMORY_AND_DISK`**.

In both cases, `cache()` is a shorthand for `persist()` using the default level.

---

#### 2. Using `persist()` with Custom Storage Levels

- `persist()` allows specifying a storage level explicitly.
- Available levels include:
  - `MEMORY_ONLY`
  - `MEMORY_ONLY_SER` (serialized in memory): Java and Scala only; PySpark's `StorageLevel` does not define it, because Python data is always stored serialized
  - `MEMORY_AND_DISK`
  - `MEMORY_AND_DISK_SER`: Java and Scala only; not available in PySpark for the same reason
  - `DISK_ONLY`
  - And replication variants like `MEMORY_ONLY_2` (replicated twice)

Use `persist()` when you need more flexibility over memory/disk usage or serialization.

---

#### 3. Storage Levels: Trade-offs

| Storage Level         | Memory Use | CPU Use         | In Memory | On Disk | Serialized | Recomputed if It Does Not Fit in Memory |
|-----------------------|------------|------------------|-----------|---------|------------|----------------------|
| `MEMORY_ONLY`         | High       | Low              | Yes       | No      | No         | Yes                  |
| `MEMORY_ONLY_SER`     | Low        | High (serialize) | Yes       | No      | Yes        | Yes                  |
| `MEMORY_AND_DISK`     | High       | Medium           | Some      | Some    | No         | No                   |
| `MEMORY_AND_DISK_SER` | Low        | High             | Some      | Some    | Yes        | No                   |
| `DISK_ONLY`           | Low        | High (I/O)       | No        | Yes     | Yes        | No                   |

---

#### 4. Usage Guidelines & Behavior

- **Lazy evaluation**: `cache()` and `persist()` do not store data immediately; the data is materialized when an **action** runs (e.g., `.count()`). An action such as `.show()` may cache only the partitions it actually reads.
- **Viewing storage**: Use the Spark UI under the “Storage” tab to monitor cached datasets.
- **Removing cached data**: Use `.unpersist()` to free up memory when a DataFrame or RDD is no longer needed.

---

In [112]:
user_df = spark.read.parquet('data/user.parquet')

In [113]:
user_df.show(2,False)

+---+---+------+-------+---------------+----------+
|id |age|gender|zipcode|_corrupt_record|occupation|
+---+---+------+-------+---------------+----------+
|262|19 |F     |78264  |NULL           |student   |
|140|30 |F     |32250  |NULL           |student   |
+---+---+------+-------+---------------+----------+
only showing top 2 rows



In [114]:
# Cache the df and it stores the data in MEMORY_AND_DISK
new_user_df = user_df.cache()

In [115]:
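from pyspark import StorageLevel
# Note: new_user_df (the same plan) is already cached, so Spark keeps the existing MEMORY_AND_DISK level
# and only logs a warning; call unpersist() first to actually switch to MEMORY_ONLY
new_user_df_2 = new_user_df.persist(StorageLevel.MEMORY_ONLY)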

In [116]:
new_user_df_2

DataFrame[id: int, age: int, gender: string, zipcode: int, _corrupt_record: string, occupation: string]

In [117]:
# Unpersist the cache df
new_user_df.unpersist()

DataFrame[id: int, age: int, gender: string, zipcode: int, _corrupt_record: string, occupation: string]

In [118]:
# Unpersist the persisting df
new_user_df_2.unpersist()

DataFrame[id: int, age: int, gender: string, zipcode: int, _corrupt_record: string, occupation: string]

In [119]:
# Remove all cached tables and DataFrames from the session's in-memory cache
spark.catalog.clearCache()

### Broadcast Variable and Accumulators

#### What is a Broadcast Variable?
 - **A broadcast variable allows you to share a read-only variable across all Spark executors efficiently**.
 - Useful when you have a lookup table or reference dataset that fits in each executor's memory and is needed by many tasks.
 - Spark sends the variable once to each node rather than with every task, saving network I/O.

When to Use:
- Joining a small dataset with a large RDD/DataFrame (map-side join).
- Any read-only data that is reused across tasks.

In [120]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf

# Initialize Spark
spark = SparkSession.builder.appName("BroadcastExample").getOrCreate()

# Example DataFrame
data = [("Alice", "US"), ("Bob", "UK"), ("Charlie", "US"), ("David", "CA")]
df = spark.createDataFrame(data, ["name", "country"])

# Small lookup table: country -> continent
country_continent = {"US": "North America", "UK": "Europe", "CA": "North America"}

# Broadcast the lookup table
broadcast_var = spark.sparkContext.broadcast(country_continent)

# Define UDF using broadcast variable
def get_continent(country):
    return broadcast_var.value.get(country, "Unknown")

get_continent_udf = udf(get_continent)

# Apply UDF to DataFrame
df.withColumn("continent", get_continent_udf(col("country"))).show()

+-------+-------+-------------+
|   name|country|    continent|
+-------+-------+-------------+
|  Alice|     US|North America|
|    Bob|     UK|       Europe|
|Charlie|     US|North America|
|  David|     CA|North America|
+-------+-------+-------------+



# 🔀 Spark Join Strategies: Shuffle Hash Join vs Sort Merge Join vs Broadcast Join

In Spark, when performing **joins** between two DataFrames (or RDDs), the **execution strategy** chosen by the Catalyst optimizer can have a big impact on performance.  
Here are the three main join strategies:

---

## 1️⃣ Shuffle Hash Join (SHJ)

### ✅ How it Works
- Both input DataFrames are **shuffled** by the join key into partitions.
- A **hash table** is built on the smaller side within each partition.
- The other side scans and probes the hash table to find matches.

### 📌 Characteristics
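- Best when each shuffled partition of the smaller side is **small enough to fit in memory**.
- **Hash-based matching** makes it fast, but it supports only **equality joins** (`=`).
- Requires **shuffling** both datasets.

### ⚠️ Limitations
- Memory intensive (hash table must fit in memory).
- Not efficient for **large-to-large joins**. Spark also prefers Sort Merge Join by default (`spark.sql.join.preferSortMergeJoin=true`), so SHJ is usually selected only in specific cases, for example with a `shuffle_hash` join hint.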
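The shuffle, build, and probe steps above can be simulated in plain Python. This is a toy, single-process sketch of the idea, not Spark's implementation; the function names and sample rows are hypothetical, and rows are tuples joined on their first element.

```python
from collections import defaultdict


def hash_partition(rows, num_partitions):
    """Simulate the shuffle: send each row to a partition chosen by hashing its join key (column 0)."""
    partitions = [[] for _ in range(num_partitions)]
    for row in rows:
        partitions[hash(row[0]) % num_partitions].append(row)
    return partitions


def shuffle_hash_join(big, small, num_partitions=4):
    """Toy inner equi-join on column 0; `small` is the build side."""
    result = []
    # After the shuffle, matching keys sit at the same partition index on both sides
    for big_part, small_part in zip(hash_partition(big, num_partitions),
                                    hash_partition(small, num_partitions)):
        # Build phase: hash table over this partition of the smaller side
        table = defaultdict(list)
        for small_row in small_part:
            table[small_row[0]].append(small_row)
        # Probe phase: stream the larger side and look up matches
        for big_row in big_part:
            for small_row in table.get(big_row[0], []):
                result.append(big_row + small_row[1:])
    return result


orders = [(1, "order-a"), (2, "order-b"), (1, "order-c"), (4, "order-d")]
customers = [(1, "Alice"), (2, "Bob"), (3, "Carol")]
joined = shuffle_hash_join(orders, customers)
print(sorted(joined))
```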

---

## 2️⃣ Sort Merge Join (SMJ)

### ✅ How it Works
1. Both input DataFrames are **shuffled** by the join key.
2. Within each partition, data is **sorted**.
3. A **merge algorithm** (like merge-sort) scans both sorted partitions to find matches.

### 📌 Characteristics
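- Scales better for **very large datasets**, and it is Spark's default strategy for large equi-joins.
- Like hash joins, it requires **equality join keys**; non-equi joins (`<`, `<=`, `>`, etc.) are executed as a Broadcast Nested Loop Join or a Cartesian Product instead.
- Efficient if data is **already sorted** or partitioned.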

### ⚠️ Limitations
- Sorting step is expensive (high CPU cost).
- Requires large shuffle operations.
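The shuffle, sort, and merge steps can likewise be simulated in plain Python. A toy, single-process sketch with hypothetical names and data, not Spark's implementation; it handles duplicate keys on both sides by pairing each left row with the whole run of equal keys on the right.

```python
def hash_partition(rows, num_partitions):
    """Simulate the shuffle: route each row by hashing its join key (column 0)."""
    partitions = [[] for _ in range(num_partitions)]
    for row in rows:
        partitions[hash(row[0]) % num_partitions].append(row)
    return partitions


def merge_partition(left, right):
    """Sort both sides by key, then walk them together like the merge step of merge sort."""
    left = sorted(left, key=lambda row: row[0])
    right = sorted(right, key=lambda row: row[0])
    i = j = 0
    out = []
    while i < len(left) and j < len(right):
        left_key, right_key = left[i][0], right[j][0]
        if left_key < right_key:
            i += 1
        elif left_key > right_key:
            j += 1
        else:
            # Find the run of right rows sharing this key
            run_end = j
            while run_end < len(right) and right[run_end][0] == left_key:
                run_end += 1
            # Pair every left row with this key against that run
            while i < len(left) and left[i][0] == left_key:
                for right_row in right[j:run_end]:
                    out.append(left[i] + right_row[1:])
                i += 1
            j = run_end
    return out


def sort_merge_join(left, right, num_partitions=4):
    """Toy inner equi-join on column 0: shuffle, then sort-and-merge each partition pair."""
    result = []
    for left_part, right_part in zip(hash_partition(left, num_partitions),
                                     hash_partition(right, num_partitions)):
        result.extend(merge_partition(left_part, right_part))
    return result


left_rows = [(1, "A"), (1, "B"), (2, "C"), (5, "D")]
right_rows = [(2, "Y"), (1, "X"), (2, "Z"), (3, "Q")]
joined = sort_merge_join(left_rows, right_rows)
print(sorted(joined))
```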

---

## 3️⃣ Broadcast Hash Join (BHJ)

### ✅ How it Works
- Spark **broadcasts the smaller DataFrame** to all executors.
- Each executor keeps a **copy of the small DataFrame** in memory.
- The join runs as a **map-side join**: no shuffle required.

### 📌 Characteristics
- Extremely fast when one table is **small (below `spark.sql.autoBroadcastJoinThreshold`, 10MB by default)**.
- Avoids expensive shuffles.
- Great for **star schema joins** (fact table + dimension table).

### ⚠️ Limitations
- Only works if one dataset is **small enough to fit in executor memory** (it is also collected on the driver before being broadcast).
- Can cause **OOM errors** if broadcast size is too big.

---

## 🔑 When to Use Which?

| Join Type        | Best Scenario                                    | Avoid When                                    |
|------------------|--------------------------------------------------|-----------------------------------------------|
| **Shuffle Hash** | One dataset is moderately small but not tiny      | Both datasets are very large                  |
| **Sort Merge**   | Both datasets are very large (equi-joins)        | One dataset can fit in memory (use Broadcast) |
| **Broadcast**    | One dataset is very small (below the broadcast threshold, 10MB by default) | Dataset is large → risk of OOM |



In [121]:
## ⚡ Spark Join Strategy Examples (PySpark)

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("JoinStrategies").getOrCreate()

# Example DataFrames
data1 = [(1, "A"), (2, "B"), (3, "C")]
data2 = [(1, "X"), (2, "Y"), (3, "Z")]

df1 = spark.createDataFrame(data1, ["id", "val1"])
df2 = spark.createDataFrame(data2, ["id", "val2"])
result = df1.join(df2, "id", "inner")
result.explain(True)

== Parsed Logical Plan ==
'Join UsingJoin(Inner, [id])
:- LogicalRDD [id#4556L, val1#4557], false
+- LogicalRDD [id#4560L, val2#4561], false

== Analyzed Logical Plan ==
id: bigint, val1: string, val2: string
Project [id#4556L, val1#4557, val2#4561]
+- Join Inner, (id#4556L = id#4560L)
   :- LogicalRDD [id#4556L, val1#4557], false
   +- LogicalRDD [id#4560L, val2#4561], false

== Optimized Logical Plan ==
Project [id#4556L, val1#4557, val2#4561]
+- Join Inner, (id#4556L = id#4560L)
   :- Filter isnotnull(id#4556L)
   :  +- LogicalRDD [id#4556L, val1#4557], false
   +- Filter isnotnull(id#4560L)
      +- LogicalRDD [id#4560L, val2#4561], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [id#4556L, val1#4557, val2#4561]
   +- SortMergeJoin [id#4556L], [id#4560L], Inner
      :- Sort [id#4556L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(id#4556L, 200), ENSURE_REQUIREMENTS, [plan_id=4447]
      :     +- Filter isnotnull(id#4556L)
      :        

In [122]:
# Force Sort-Merge by disabling automatic broadcast joins (-1 disables; a large value would raise the limit instead)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

result = df1.join(df2, "id", "inner")
result.explain(True)  # See physical plan → SortMergeJoin

== Parsed Logical Plan ==
'Join UsingJoin(Inner, [id])
:- LogicalRDD [id#4556L, val1#4557], false
+- LogicalRDD [id#4560L, val2#4561], false

== Analyzed Logical Plan ==
id: bigint, val1: string, val2: string
Project [id#4556L, val1#4557, val2#4561]
+- Join Inner, (id#4556L = id#4560L)
   :- LogicalRDD [id#4556L, val1#4557], false
   +- LogicalRDD [id#4560L, val2#4561], false

== Optimized Logical Plan ==
Project [id#4556L, val1#4557, val2#4561]
+- Join Inner, (id#4556L = id#4560L)
   :- Filter isnotnull(id#4556L)
   :  +- LogicalRDD [id#4556L, val1#4557], false
   +- Filter isnotnull(id#4560L)
      +- LogicalRDD [id#4560L, val2#4561], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [id#4556L, val1#4557, val2#4561]
   +- SortMergeJoin [id#4556L], [id#4560L], Inner
      :- Sort [id#4556L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(id#4556L, 200), ENSURE_REQUIREMENTS, [plan_id=4478]
      :     +- Filter isnotnull(id#4556L)
      :        

In [123]:
from pyspark.sql.functions import broadcast

result = df1.join(broadcast(df2), "id", "inner")
result.explain(True)  # See physical plan → BroadcastHashJoin

== Parsed Logical Plan ==
'Join UsingJoin(Inner, [id])
:- LogicalRDD [id#4556L, val1#4557], false
+- ResolvedHint (strategy=broadcast)
   +- LogicalRDD [id#4560L, val2#4561], false

== Analyzed Logical Plan ==
id: bigint, val1: string, val2: string
Project [id#4556L, val1#4557, val2#4561]
+- Join Inner, (id#4556L = id#4560L)
   :- LogicalRDD [id#4556L, val1#4557], false
   +- ResolvedHint (strategy=broadcast)
      +- LogicalRDD [id#4560L, val2#4561], false

== Optimized Logical Plan ==
Project [id#4556L, val1#4557, val2#4561]
+- Join Inner, (id#4556L = id#4560L), rightHint=(strategy=broadcast)
   :- Filter isnotnull(id#4556L)
   :  +- LogicalRDD [id#4556L, val1#4557], false
   +- Filter isnotnull(id#4560L)
      +- LogicalRDD [id#4560L, val2#4561], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [id#4556L, val1#4557, val2#4561]
   +- BroadcastHashJoin [id#4556L], [id#4560L], Inner, BuildRight, false
      :- Filter isnotnull(id#4556L)
      :  +- Scan Existin