# Module 4 Notebook 3: Performance Tuning Concepts

**Objective:** Explore and demonstrate the syntax and concepts behind common techniques for optimizing the performance of PySpark ML workloads, particularly for large-scale data.


In [0]:
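# Necessary imports
# Only the functions used below are imported: pulling `sum`/`min` from
# pyspark.sql.functions into the namespace would shadow Python's built-ins.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, when, broadcast
from pyspark.sql.window import Window
from pyspark.storagelevel import StorageLevel

# On Databricks `spark` is predefined; getOrCreate() simply returns that session
spark = SparkSession.builder.getOrCreate()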


## 1. Data Preparation (Baseline Reference)

First, let's quickly reload and prepare the `regression_base_data` DataFrame as we did in M4N1. This will serve as the input for demonstrating the optimization techniques.

In [0]:
# Load raw data (assuming tables exist from previous notebooks)
interactions_df = spark.table("ecommerce.interactions")
customers_df = spark.table("ecommerce.customers")
products_df = spark.table("ecommerce.products")

# --- Minimal Feature Engineering Logic (from M4N1) ---
# We only need enough to join and filter for the regression base:
# count purchases per (customer, product) pair so we can filter on it below.
customer_product_agg = interactions_df.groupBy("customer_id", "product_id").agg(
    count(when(col("interaction_type") == "purchase", 1)).alias("purchase_count")
)

# --- Join Data --- 
base_df = customer_product_agg.join(customers_df, "customer_id").join(products_df, "product_id")

# --- Filter for Regression Base --- 
regression_base_data = base_df.filter(col("purchase_count") > 0).select("customer_id", "product_id", "age", "country", "price") # Select a few cols for demo

print("Created baseline DataFrame `regression_base_data` for demonstration.")
regression_base_data.limit(5).display()


Created baseline DataFrame `regression_base_data` for demonstration.


| customer_id | product_id | age | country | price |
|---|---|---|---|---|
| 3245 | 252 | 32 | US | 52.16 |
| 5843 | 115 | 18 | CA | 38.71 |
| 8519 | 240 | 27 | JP | 414.86 |
| 7752 | 980 | 38 | FR | 74.04 |
| 317 | 682 | 43 | US | 177.74 |


## 2. Optimization Technique 1: Caching (`.cache()` / `.persist()`)

**Concept:** Caching stores the result of a DataFrame computation in memory (and, depending on the storage level, on disk). If you use the same DataFrame in several subsequent actions, Spark can reuse the cached result instead of recomputing it from scratch. This is extremely beneficial for iterative algorithms (like many ML training loops) or when branching off multiple analyses from one expensive data preparation step, **especially with large datasets**. Caching is not free, though: it occupies executor memory (or disk), so unpersist data you no longer need.

- For DataFrames, `.cache()` is a shortcut for `.persist()` with the default storage level, `MEMORY_AND_DISK` (reported as `MEMORY_AND_DISK_DESER` in PySpark 3.4 and later). `RDD.cache()`, by contrast, defaults to `MEMORY_ONLY`.
- `.persist(level)` allows specifying different storage levels (e.g., `MEMORY_ONLY`, `MEMORY_AND_DISK`, `DISK_ONLY`). `MEMORY_AND_DISK` is often a good default because partitions that don't fit in memory are spilled to disk instead of being recomputed.
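
To make the "mark now, materialize on the first action, reuse afterwards" behaviour concrete, here is a pure-Python sketch. It is not Spark code: the class `LazyFrame` and the function `expensive_prep` are hypothetical stand-ins for a DataFrame and its lineage. The sketch only counts how often the lineage actually runs with and without persisting.

```python
class LazyFrame:
    """Hypothetical stand-in for a DataFrame: holds a recipe, not data."""

    def __init__(self, compute):
        self._compute = compute   # the "lineage": how to rebuild the rows
        self._persisted = False   # like persist(): only a mark, nothing stored yet
        self._cache = None        # filled by the first action after persist()
        self.computations = 0     # how many times the lineage actually ran

    def persist(self):
        self._persisted = True    # lazy: no computation happens here
        return self

    def unpersist(self):
        self._persisted = False
        self._cache = None        # drop the stored result
        return self

    def _rows(self):
        if self._persisted and self._cache is not None:
            return self._cache    # reuse the materialized result
        self.computations += 1
        rows = self._compute()
        if self._persisted:
            self._cache = rows    # first action after persist() fills the cache
        return rows

    # "Actions" force evaluation, like count() or show() in Spark
    def count(self):
        return len(self._rows())

    def take(self, n):
        return self._rows()[:n]


def expensive_prep():
    # Pretend this is a costly join + filter
    return [x * x for x in range(1000)]


plain = LazyFrame(expensive_prep)
plain.count()
plain.take(5)
print(plain.computations)   # 2: recomputed for every action

cached = LazyFrame(expensive_prep).persist()
print(cached.computations)  # 0: persist() alone does nothing
cached.count()
cached.take(5)
print(cached.computations)  # 1: computed once, then reused
```

In real Spark the "computation" is a whole chain of stages and the cache lives in executor block managers, but the accounting is the same: without persisting, every action re-runs the lineage.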

**Demonstration:**


In [0]:
# Demonstrate persisting the DataFrame
print("Persisting regression_base_data...")
cached_data = regression_base_data.persist(StorageLevel.MEMORY_AND_DISK)

# persist() is lazy: it only marks the DataFrame; the cache is filled by the first action
print("Triggering action to populate cache (e.g., count)...")
cache_count = cached_data.count()
print(f"DataFrame count (cached): {cache_count}")

# Subsequent actions on `cached_data` read from the cache instead of recomputing the joins
print("Simulating reuse (e.g., show). On large data, this would reuse the cache.")
cached_data.show(5)

# Release the cached blocks once the DataFrame is no longer needed
print("Unpersisting the data to free up resources...")
cached_data.unpersist()
print("Data unpersisted.")


Persisting regression_base_data...
Triggering action to populate cache (e.g., count)...
DataFrame count (cached): 25238
Simulating reuse (e.g., show). On large data, this would reuse the cache.
+-----------+----------+---+-------+------+
|customer_id|product_id|age|country| price|
+-----------+----------+---+-------+------+
|       3245|       252| 32|     US| 52.16|
|       5843|       115| 18|     CA| 38.71|
|       8519|       240| 27|     JP|414.86|
|       7752|       980| 38|     FR| 74.04|
|        317|       682| 43|     US|177.74|
+-----------+----------+---+-------+------+
only showing top 5 rows
Unpersisting the data to free up resources...
Data unpersisted.


## 3. Optimization Technique 2: Data Partitioning (`.repartition()` / `.coalesce()`)

**Concept:** Spark processes data in parallel using partitions. The number of partitions can significantly impact performance.
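- **Too few partitions:** Limits parallelism; some executor cores sit idle, and oversized partitions are more likely to spill to disk.
- **Too many partitions:** Per-task scheduling overhead dominates when each task is tiny, and writing the result can produce many small files.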
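
A common way to pick a starting point is simple arithmetic on data size and cluster cores. The helper below is a hypothetical rule of thumb, not a Spark API: the 128 MB target mirrors the default of `spark.sql.files.maxPartitionBytes`, and the `tasks_per_core=2` floor follows the Spark tuning guide's suggestion of 2-3 tasks per CPU core. Treat the result as an initial guess to validate in the Spark UI.

```python
import math


def suggest_num_partitions(data_size_bytes, total_cores,
                           target_partition_bytes=128 * 1024 * 1024,
                           tasks_per_core=2):
    """Rule-of-thumb partition count (an assumption, not a Spark API)."""
    # Enough partitions that each holds roughly target_partition_bytes...
    by_size = math.ceil(data_size_bytes / target_partition_bytes)
    # ...but not so few that cores sit idle
    by_cores = total_cores * tasks_per_core
    return max(by_size, by_cores, 1)


GiB = 1024 ** 3
print(suggest_num_partitions(50 * GiB, total_cores=32))     # 400: data size dominates
print(suggest_num_partitions(2 * 1024 ** 2, total_cores=8))  # 16: core count dominates
```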

- `.repartition(N)`: Changes the number of partitions to `N`. **Involves a full shuffle** of the data across the network, which is expensive. Use when you need to increase partitions or change partitioning keys.
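- `.coalesce(N)`: Reduces the number of partitions to `N` by merging existing partitions. **Avoids a full shuffle**, so it is much cheaper than `repartition` when you only need fewer partitions. If `N` is greater than the current number of partitions, the count stays unchanged; use `repartition` to increase it.
- `repartition(N, col("key"))`: Partitions data based on the hash of one or more key columns, so all rows with the same key land in the same partition. Colocating related data this way can let a subsequent aggregation or join on those keys reuse the partitioning instead of shuffling again, which pays off on large datasets. It does not fix skew, however: every row of a heavily repeated key still ends up in a single partition.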
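
The difference between hash repartitioning and coalescing can be sketched in plain Python on lists of rows. This is an illustration under stated assumptions, not Spark's implementation: `hash_partition` and `coalesce` are hypothetical helpers, `zlib.crc32` stands in for the Murmur3-based hash Spark actually uses, and the rows are toy data.

```python
import zlib


def hash_partition(rows, num_partitions, key):
    """Sketch of repartition(N, col(key)): every row is routed by the hash of its key."""
    partitions = [[] for _ in range(num_partitions)]
    for row in rows:
        # crc32 is stable across runs; Spark itself uses a Murmur3-based hash
        idx = zlib.crc32(str(row[key]).encode()) % num_partitions
        partitions[idx].append(row)
    return partitions


def coalesce(partitions, n):
    """Sketch of coalesce(n): whole partitions are merged; no row is re-hashed."""
    # Asking for more partitions than exist leaves the count unchanged
    n = max(1, min(n, len(partitions)))
    merged = [[] for _ in range(n)]
    for i, part in enumerate(partitions):
        # Adjacent input partitions are appended to the same output slot
        merged[i * n // len(partitions)].extend(part)
    return merged


# Toy rows: 40 interactions spread over 7 distinct customer ids
rows = [{"customer_id": c % 7, "product_id": p}
        for c, p in zip(range(40), range(100, 140))]
by_key = hash_partition(rows, 10, "customer_id")
fewer = coalesce(by_key, 5)
print([len(p) for p in by_key])  # every row of a given customer sits in one partition
print([len(p) for p in fewer])   # 5 partitions, each the union of 2 original ones
```

Because `hash_partition` touches every row while `coalesce` only concatenates existing lists, the sketch mirrors why `repartition` needs a full shuffle and `coalesce` does not. It also shows the skew caveat: with only 7 distinct keys, at most 7 of the 10 partitions can receive any rows.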

**Demonstration:**


In [0]:
# Get current number of partitions
num_partitions_before = regression_base_data.rdd.getNumPartitions()
print(f"Number of partitions before: {num_partitions_before}")

# Demonstrate repartition (e.g., doubling, minimum 10 for demo)
target_repartition = max(10, num_partitions_before * 2)
print(f"Demonstrating repartition to {target_repartition} partitions...")
repartitioned_data = regression_base_data.repartition(target_repartition)
num_partitions_after_repart = repartitioned_data.rdd.getNumPartitions()
print(f"Number of partitions after repartition: {num_partitions_after_repart}")
# repartition() always performs a full shuffle, which is expensive on large data

# Demonstrate coalesce (e.g., halving, minimum 1)
target_coalesce = max(1, num_partitions_after_repart // 2)
print(f"Demonstrating coalesce to {target_coalesce} partitions...")
coalesced_data = repartitioned_data.coalesce(target_coalesce)
num_partitions_after_coal = coalesced_data.rdd.getNumPartitions()
print(f"Number of partitions after coalesce: {num_partitions_after_coal}")
# coalesce() merges existing partitions without a full shuffle, so it is cheaper

# Demonstrate repartition by key
print("Demonstrating repartition by key:")
# Useful if we frequently joined or grouped by customer_id on large data
partitioned_by_key_data = regression_base_data.repartition(target_repartition, col("customer_id"))
print("Created DataFrame repartitioned by 'customer_id'")


Number of partitions before: 1
Demonstrating repartition to 10 partitions...
Number of partitions after repartition: 10
Demonstrating coalesce to 5 partitions...
Number of partitions after coalesce: 5
Demonstrating repartition by key:
Created DataFrame repartitioned by 'customer_id'


## 4. Optimization Technique 3: Broadcast Joins

**Concept:** When joining a large DataFrame with a much smaller one, Spark can perform a *Broadcast Hash Join*: it sends (broadcasts) a full copy of the small DataFrame to every executor, and each executor then joins its local partitions of the large DataFrame against that copy. This avoids shuffling the large DataFrame's data across the network, which other join strategies such as Sort-Merge Join require, and is often a large performance win. The trade-off is that the small side must fit comfortably in driver and executor memory.

- Spark does this automatically based on the `spark.sql.autoBroadcastJoinThreshold` configuration (10 MB by default in open-source Spark; setting it to `-1` disables automatic broadcasting). If Spark's size estimate for one side of the join is at or below this threshold, that side is usually broadcast.
- You can verify which join strategy was used by calling `.explain()` on the joined DataFrame or by examining the query plan in the Spark UI's SQL tab.
- You can also *hint* to Spark to broadcast a specific DataFrame using the `broadcast()` function from `pyspark.sql.functions`.
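
The two halves of this idea, the size check and the join itself, can be sketched in plain Python. This is a simplified illustration, not Spark's planner: `would_auto_broadcast` and `broadcast_hash_join` are hypothetical helpers, the size check assumes an accurate size estimate (Spark's is often rough), and the rows are toy data. The large side stays in its partitions; only the small side is turned into a hash table that every partition probes.

```python
from collections import defaultdict

# Documented default of spark.sql.autoBroadcastJoinThreshold: 10 MB
AUTO_BROADCAST_THRESHOLD = 10 * 1024 * 1024


def would_auto_broadcast(estimated_size_bytes, threshold=AUTO_BROADCAST_THRESHOLD):
    """Simplified size check; a threshold of -1 makes this always False."""
    return 0 <= estimated_size_bytes <= threshold


def broadcast_hash_join(large_partitions, small_rows, key):
    # 1. Build phase (done once, then shipped to every task): hash the small side
    build = defaultdict(list)
    for row in small_rows:
        build[row[key]].append(row)
    # 2. Probe phase: each large-side partition is joined where it sits,
    #    so its rows are never shuffled across the network
    joined_partitions = []
    for part in large_partitions:
        out = []
        for left in part:
            for right in build.get(left[key], []):
                out.append({**left, **right})  # inner join: unmatched rows are dropped
        joined_partitions.append(out)
    return joined_partitions


# Toy data: the large side is already split into two partitions
interactions = [
    [{"customer_id": 1, "product_id": 252}, {"customer_id": 2, "product_id": 115}],
    [{"customer_id": 3, "product_id": 240}, {"customer_id": 1, "product_id": 980}],
]
customers = [{"customer_id": 1, "country": "US"}, {"customer_id": 2, "country": "CA"}]

print(would_auto_broadcast(2 * 1024 * 1024))  # True: 2 MB is under the 10 MB default
result = broadcast_hash_join(interactions, customers, "customer_id")
print([len(p) for p in result])  # [2, 1]: customer 3 has no match and is dropped
```

A sort-merge join would instead have to move rows from both sides so that equal keys meet in the same partition; the broadcast variant trades that shuffle for copying the small table to every executor.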

**Demonstration:**


In [0]:
# Reload data for join demonstration
interactions_large_df = spark.table("ecommerce.interactions")
customers_small_df = spark.table("ecommerce.customers")

# Demonstrate broadcast hint syntax
print("Demonstrating broadcast hint syntax...")
hinted_join_df = interactions_large_df.join(broadcast(customers_small_df), "customer_id")
hinted_join_df.count() # Action to trigger join execution
print("Hinted broadcast join complete. The physical plan should show a BroadcastHashJoin.")
# The hint is useful when auto-broadcast doesn't trigger (e.g., the size estimate is
# missing or above the threshold) or when you want explicit control over the join strategy.

hinted_join_df.explain(mode="formatted")


Demonstrating broadcast hint syntax...
Hinted broadcast join complete. The physical plan should show a BroadcastHashJoin.
== Physical Plan ==
AdaptiveSparkPlan (10)
+- == Initial Plan ==
   Project (9)
   +- BroadcastHashJoin Inner BuildRight (8)
      :- Project (3)
      :  +- Filter (2)
      :     +- Scan parquet spark_catalog.ecommerce.interactions (1)
      +- Exchange (7)
         +- Project (6)
            +- Filter (5)
               +- Scan parquet spark_catalog.ecommerce.customers (4)


(1) Scan parquet spark_catalog.ecommerce.interactions
Output [10]: [customer_id#15842, product_id#15843L, timestamp#15844, interaction_type#15845, time_spent_seconds#15846L, purchase_amount#15847, user_rating#15848, device#15849, previous_visits#15850L, _databricks_internal_edge_computed_column_skip_row#15960]
Batched: true
Location: PreparedDeltaFileIndex [dbfs:/user/hive/warehouse/ecommerce.db/interactions]
PushedFilters: [IsNotNull(customer_id)]
ReadSchema: struct<customer_id:int,p

## 5. Conclusion

In this notebook, we demonstrated the syntax and underlying concepts for three key Spark performance optimization techniques:
1.  **Caching/Persistence:** Avoiding recomputation of expensive DataFrame operations.
2.  **Partitioning:** Controlling data distribution and parallelism using `repartition` and `coalesce`.
3.  **Broadcast Joins:** Efficiently joining large DataFrames with small ones by avoiding shuffles.

**Key Takeaway:** While the effects weren't dramatic on our small course dataset, understanding and applying these techniques is **essential** for building performant and scalable ML pipelines and data processing jobs on large, real-world datasets. Always use the **Spark UI** to monitor your job's execution, identify bottlenecks (like shuffles or stages taking too long), and verify whether your optimizations are having the intended effect.

This concludes Module 4 and the PySpark MLlib Course! We hope you now have a solid foundation for building scalable machine learning applications with PySpark.
