# PySpark Joins and Performance Optimization

## Overview
This notebook covers all types of joins in PySpark, union operations, and performance optimization techniques.

## Learning Objectives
- Master all join types (inner, outer, left, right, cross, semi, anti)
- Understand join strategies and performance
- Use broadcast joins effectively
- Combine DataFrames with union and unionByName
- Optimize DataFrame operations
- Handle data skew

---

## 1. Join Types

In [None]:
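from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

# Reuse the notebook's session if one exists (e.g. on Databricks), else create one.
# display() is provided by Databricks-style notebooks; use .show() elsewhere.
spark = SparkSession.builder.getOrCreate()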

# Sample data - Customers
customers_data = [
    (1, "Alice", "NY"),
    (2, "Bob", "CA"),
    (3, "Charlie", "TX"),
    (4, "Diana", "FL")
]

customers = spark.createDataFrame(
    customers_data,
    ["customer_id", "name", "state"]
)

# Sample data - Orders
orders_data = [
    (101, 1, 150.0, "2024-01-15"),
    (102, 1, 200.0, "2024-01-16"),
    (103, 2, 75.0, "2024-01-17"),
    (104, 3, 300.0, "2024-01-18"),
    (105, 5, 120.0, "2024-01-19")  # customer_id 5 doesn't exist
]

orders = spark.createDataFrame(
    orders_data,
    ["order_id", "customer_id", "amount", "order_date"]
)

print("Customers:")
display(customers)

print("\nOrders:")
display(orders)

### Inner Join
Returns only matching rows from both DataFrames.

In [None]:
# Inner join - only matching records
inner_join = customers.join(
    orders,
    customers.customer_id == orders.customer_id,
    "inner"
).select(
    customers.customer_id,
    customers.name,
    orders.order_id,
    orders.amount
)

print("Inner Join Result:")
display(inner_join)
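To make the semantics concrete without a cluster, here is a plain-Python model of an inner equi-join on the same sample rows. It sketches the build/probe idea behind a hash join and is not Spark's implementation. The names `by_id` and `inner` are illustrative.

```python
# Plain-Python model of an inner equi-join (not Spark code): build a hash
# table on one side, then probe it with each row of the other side.
from collections import defaultdict

customers = [(1, "Alice", "NY"), (2, "Bob", "CA"), (3, "Charlie", "TX"), (4, "Diana", "FL")]
orders = [
    (101, 1, 150.0, "2024-01-15"),
    (102, 1, 200.0, "2024-01-16"),
    (103, 2, 75.0, "2024-01-17"),
    (104, 3, 300.0, "2024-01-18"),
    (105, 5, 120.0, "2024-01-19"),
]

# Build phase: customer_id -> list of customer rows
by_id = defaultdict(list)
for cust in customers:
    by_id[cust[0]].append(cust)

# Probe phase: emit one output row per matching (customer, order) pair.
# Order 105 finds no customer and Diana is never probed, so both drop out.
inner = [
    (cust[0], cust[1], order_id, amount)
    for order_id, cust_id, amount, _ in orders
    for cust in by_id.get(cust_id, [])
]

for row in inner:
    print(row)
```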

### Left Outer Join
Returns all rows from the left DataFrame. The right-side columns are null for left rows with no match.

In [None]:
# Left join - all customers, even without orders
left_join = customers.join(
    orders,
    customers.customer_id == orders.customer_id,
    "left"
).select(
    customers.customer_id,
    customers.name,
    orders.order_id,
    orders.amount
)

print("Left Join Result (all customers):")
display(left_join)

### Right Outer Join
Returns all rows from the right DataFrame. The left-side columns are null for right rows with no match.

In [None]:
# Right join - all orders, even without matching customer
right_join = customers.join(
    orders,
    customers.customer_id == orders.customer_id,
    "right"
).select(
    orders.order_id,
    orders.customer_id.alias("order_customer_id"),
    customers.customer_id.alias("cust_id"),
    customers.name,
    orders.amount
)

print("Right Join Result (all orders):")
display(right_join)

### Full Outer Join
Returns all rows from both DataFrames, with nulls where there's no match.

In [None]:
# Full outer join - all customers and all orders
full_join = customers.join(
    orders,
    customers.customer_id == orders.customer_id,
    "full_outer"
).select(
    coalesce(customers.customer_id, orders.customer_id).alias("customer_id"),
    customers.name,
    orders.order_id,
    orders.amount
)

print("Full Outer Join Result:")
display(full_join)
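The outer-join variants differ only in which unmatched rows they keep. This plain-Python sketch (not Spark code) reproduces the row counts expected for the sample data. Diana has no orders, and order 105 has no customer.

```python
# Plain-Python model of outer-join semantics on the sample data (not Spark code).
customers = {1: "Alice", 2: "Bob", 3: "Charlie", 4: "Diana"}  # customer_id -> name
orders = [(101, 1, 150.0), (102, 1, 200.0), (103, 2, 75.0), (104, 3, 300.0), (105, 5, 120.0)]

# Matched pairs: (customer_id, name, order_id, amount)
matched = [(cid, customers[cid], oid, amt) for oid, cid, amt in orders if cid in customers]

ordered_ids = {cid for _, cid, _ in orders}
# Left rows with no match: the order columns become None
left_only = [(cid, name, None, None) for cid, name in customers.items() if cid not in ordered_ids]
# Right rows with no match: the name is None (key shown coalesced, as in the full-join cell)
right_only = [(cid, None, oid, amt) for oid, cid, amt in orders if cid not in customers]

left_join = matched + left_only
right_join = matched + right_only
full_join = matched + left_only + right_only

print(len(matched), len(left_join), len(right_join), len(full_join))  # inner, left, right, full
```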

### Cross Join
Cartesian product: every row from the left is paired with every row from the right, giving left count × right count rows.

In [None]:
# Cross join - Cartesian product (use carefully!)
cross_join = customers.crossJoin(orders).select(
    customers.customer_id.alias("cust_id"),
    customers.name,
    orders.order_id,
    orders.customer_id.alias("order_cust_id")
)

print(f"Cross Join Result ({customers.count()} × {orders.count()} = {cross_join.count()} rows):")
display(cross_join.limit(10))
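Cross-join output grows multiplicatively. This plain-Python sketch uses `itertools.product` to produce the 4 × 5 = 20 pairs from the sample tables, then applies the same arithmetic at a larger scale.

```python
# Plain-Python model of a cross join (not Spark code): output size is the
# product of the input sizes, so it grows quickly.
from itertools import product

customer_ids = [1, 2, 3, 4]
order_ids = [101, 102, 103, 104, 105]

pairs = list(product(customer_ids, order_ids))
print(len(pairs))   # 4 * 5 = 20
print(pairs[:3])

# Scaling: two 100,000-row tables would produce 10 billion rows
print(100_000 * 100_000)
```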

### Left Semi Join
Returns rows from the left DataFrame that have at least one match in the right (like EXISTS in SQL). Only left-side columns are returned, and each left row appears at most once.

In [None]:
# Left semi join - customers who have placed orders
semi_join = customers.join(
    orders,
    customers.customer_id == orders.customer_id,
    "left_semi"
)

print("Left Semi Join (customers with orders):")
display(semi_join)

### Left Anti Join
Returns rows from the left DataFrame that have no match in the right (like NOT EXISTS in SQL). Only left-side columns are returned.

In [None]:
# Left anti join - customers without orders
anti_join = customers.join(
    orders,
    customers.customer_id == orders.customer_id,
    "left_anti"
)

print("Left Anti Join (customers without orders):")
display(anti_join)
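Semi and anti joins only test whether a key exists on the right side. This plain-Python model (not Spark code) shows why Alice appears once in the semi join despite having two orders, and why Diana is the only row in the anti join.

```python
# Plain-Python model of left semi and left anti joins (not Spark code).
customers = [(1, "Alice", "NY"), (2, "Bob", "CA"), (3, "Charlie", "TX"), (4, "Diana", "FL")]
order_customer_ids = [1, 1, 2, 3, 5]  # customer_id column of the orders table

keys = set(order_customer_ids)  # only key existence matters, not how many matches

# Semi: keep left rows with at least one match; each left row appears at most once
semi = [c for c in customers if c[0] in keys]
# Anti: keep left rows with no match
anti = [c for c in customers if c[0] not in keys]

print(semi)  # Alice appears once even though she has two orders
print(anti)
```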

## 2. Multiple Join Conditions

In [None]:
# Sample data with multiple join keys
products_data = [
    (1, "A", "Product 1"),
    (1, "B", "Product 2"),
    (2, "A", "Product 3"),
    (2, "B", "Product 4")
]

products = spark.createDataFrame(
    products_data,
    ["category_id", "region", "product_name"]
)

sales_data = [
    (1, "A", 100),
    (1, "B", 150),
    (2, "A", 200)
]

sales = spark.createDataFrame(
    sales_data,
    ["category_id", "region", "sales_amount"]
)

# Join on multiple conditions
multi_join = products.join(
    sales,
    (products.category_id == sales.category_id) & 
    (products.region == sales.region),
    "inner"
).select(
    products.product_name,
    products.region,
    sales.sales_amount
)

print("Multi-condition Join:")
display(multi_join)
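A multi-condition equi-join behaves like a join on a composite key. In this plain-Python sketch (not Spark code), the key is the tuple `(category_id, region)`. `sales_by_key` is an illustrative name.

```python
# Plain-Python model of a multi-column equi-join (not Spark code): the join
# key is the tuple (category_id, region), so both parts must match.
products = [(1, "A", "Product 1"), (1, "B", "Product 2"), (2, "A", "Product 3"), (2, "B", "Product 4")]
sales = [(1, "A", 100), (1, "B", 150), (2, "A", 200)]

sales_by_key = {}
for cat, region, amount in sales:
    sales_by_key.setdefault((cat, region), []).append(amount)

multi = [
    (name, region, amount)
    for cat, region, name in products
    for amount in sales_by_key.get((cat, region), [])
]
print(multi)  # Product 4 has no (2, "B") sale, so the inner join drops it
```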

## 3. Broadcast Joins (Performance Optimization)

In [None]:
from pyspark.sql.functions import broadcast

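# Spark broadcasts a table automatically when its estimated size is below
# spark.sql.autoBroadcastJoinThreshold (default 10 MB); broadcast() forces it.
# Every executor receives a full copy of the small table, so the large
# DataFrame is joined in place instead of being shuffled.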

broadcast_join = orders.join(
    broadcast(customers),  # Broadcast small table
    "customer_id",
    "inner"
).select(
    customers.name,
    orders.order_id,
    orders.amount
)

print("Broadcast Join (optimized):")
display(broadcast_join)

# Check execution plan: expect BroadcastHashJoin rather than SortMergeJoin
print("\nExecution Plan:")
broadcast_join.explain()
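Broadcasting pays off because of where the work happens. This plain-Python model (not Spark's implementation) treats the large table as a list of partitions and hands each partition the same lookup dict, which stands in for the broadcast copy. The partition layout is an assumption chosen for illustration.

```python
# Plain-Python model of a broadcast hash join (not Spark code). The small
# table becomes a lookup dict that every partition of the large table can
# read, so the large side is joined in place without a shuffle.
small = {1: "Alice", 2: "Bob", 3: "Charlie", 4: "Diana"}  # customer_id -> name

# The large table, already split into partitions (assumed layout)
large_partitions = [
    [(101, 1, 150.0), (102, 1, 200.0)],
    [(103, 2, 75.0), (104, 3, 300.0)],
    [(105, 5, 120.0)],
]

def join_partition(partition, lookup):
    """Inner-join one partition against the broadcast lookup table."""
    return [(lookup[cid], oid, amt) for oid, cid, amt in partition if cid in lookup]

# Each partition is processed independently; no row moves between partitions
joined_partitions = [join_partition(p, small) for p in large_partitions]
print(joined_partitions)
```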

## 4. Self Joins

In [None]:
# Employee-Manager relationship
employees_data = [
    (1, "Alice", None),
    (2, "Bob", 1),
    (3, "Charlie", 1),
    (4, "Diana", 2),
    (5, "Eve", 2)
]

employees = spark.createDataFrame(
    employees_data,
    ["emp_id", "emp_name", "manager_id"]
)

# Self join to get manager names
emp_alias = employees.alias("emp")
mgr_alias = employees.alias("mgr")

self_join = emp_alias.join(
    mgr_alias,
    col("emp.manager_id") == col("mgr.emp_id"),
    "left"
).select(
    col("emp.emp_name").alias("employee"),
    col("mgr.emp_name").alias("manager")
)

print("Self Join (Employee-Manager):")
display(self_join)
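A self join is a lookup into a second copy of the same table. This plain-Python sketch (not Spark code) mirrors the `emp`/`mgr` aliases with a dict of names keyed by `emp_id`. Alice has no manager, so the left join leaves her manager empty.

```python
# Plain-Python model of the self join (not Spark code): the same table is used
# twice, once as "emp" and once as a lookup of manager names keyed by emp_id.
employees = [(1, "Alice", None), (2, "Bob", 1), (3, "Charlie", 1), (4, "Diana", 2), (5, "Eve", 2)]

names_by_id = {emp_id: name for emp_id, name, _ in employees}  # the "mgr" side

# Left join: an employee without a manager (manager_id is None) gets None
pairs = [(name, names_by_id.get(mgr_id)) for _, name, mgr_id in employees]
for employee, manager in pairs:
    print(employee, "->", manager)
```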

## 5. Union Operations

In [None]:
# Sample data - Q1 and Q2 sales
q1_sales = spark.createDataFrame(
    [(1, "Product A", 100), (2, "Product B", 150)],
    ["id", "product", "amount"]
)

q2_sales = spark.createDataFrame(
    [(2, "Product B", 150), (3, "Product C", 200)],  # (2, "Product B", 150) also appears in q1_sales
    ["id", "product", "amount"]
)

# union() appends rows and KEEPS duplicates (SQL UNION ALL semantics);
# columns are matched by position, not by name
union_result = q1_sales.union(q2_sales)
print(f"Union (with duplicates): {union_result.count()} rows")
display(union_result)

# unionAll() is an alias of union(): same result, duplicates kept
union_all_result = q1_sales.unionAll(q2_sales)
print(f"\nUnion All: {union_all_result.count()} rows")
display(union_all_result)

# SQL-style UNION (deduplicated) = union() followed by distinct()
distinct_result = union_all_result.distinct()
print(f"\nDistinct after Union: {distinct_result.count()} rows")
display(distinct_result)
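In PySpark, `union()` and `unionAll()` both behave like SQL `UNION ALL`. A deduplicated SQL `UNION` is `union()` followed by `distinct()`. The plain-Python model below (not Spark code) reproduces the 4-row and 3-row counts for the sample data.

```python
# Plain-Python model of union semantics (not Spark code).
q1 = [(1, "Product A", 100), (2, "Product B", 150)]
q2 = [(2, "Product B", 150), (3, "Product C", 200)]

union_all = q1 + q2                        # like union()/unionAll(): keeps the duplicate
deduped = list(dict.fromkeys(union_all))   # like distinct(); this model keeps first-seen order

print(len(union_all), len(deduped))
```

Unlike this model, Spark's `distinct()` does not guarantee any row order.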

### Union by Name

In [None]:
# DataFrames with different column orders
df1 = spark.createDataFrame(
    [(1, "A", 100)],
    ["id", "name", "value"]
)

df2 = spark.createDataFrame(
    [(200, 2, "B")],
    ["value", "id", "name"]  # Different order!
)

# unionByName - matches by column name, not position
union_by_name = df1.unionByName(df2)
print("Union By Name (handles different column orders):")
display(union_by_name)
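Positional and name-based union differ by a single column-reordering step. This plain-Python sketch (not Spark code) shows how the values in `df2` end up misaligned when the DataFrames are combined by position, and how reordering by name fixes it.

```python
# Plain-Python model contrasting positional union with union by name (not Spark code).
df1_cols = ["id", "name", "value"]
df1_rows = [(1, "A", 100)]
df2_cols = ["value", "id", "name"]
df2_rows = [(200, 2, "B")]

# Positional (like union()): df2's values are taken in df2's own column order
positional = df1_rows + df2_rows

# By name (like unionByName()): reorder each df2 row to df1's column order first
index = [df2_cols.index(c) for c in df1_cols]
by_name = df1_rows + [tuple(row[i] for i in index) for row in df2_rows]

print(positional)  # second row is misaligned: 200 lands under "id"
print(by_name)     # second row is (2, "B", 200)
```

From Spark 3.1 onward, `unionByName` also accepts `allowMissingColumns=True`, which fills columns missing from either side with nulls.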

## 6. Performance Optimization Techniques

### Caching

In [None]:
# cache() is lazy: the data is stored by the first action and reused by later ones
customers_cached = customers.cache()

# Use cached DataFrame multiple times
count1 = customers_cached.count()
count2 = customers_cached.filter(col("state") == "NY").count()

# Unpersist when done
customers_cached.unpersist()

print(f"Total customers: {count1}")
print(f"NY customers: {count2}")

### Repartitioning

In [None]:
# Check current partitions
print(f"Current partitions: {orders.rdd.getNumPartitions()}")

# Repartition (increases or decreases partitions with shuffle)
orders_repart = orders.repartition(4)
print(f"After repartition: {orders_repart.rdd.getNumPartitions()}")

# Repartition by column: hash-partitions so all rows with the same customer_id share a partition
orders_by_customer = orders.repartition(4, "customer_id")
print(f"Repartitioned by customer_id: {orders_by_customer.rdd.getNumPartitions()}")

# Coalesce (only reduces partitions by merging existing ones, avoiding a full shuffle)
orders_coal = orders.coalesce(2)
print(f"After coalesce: {orders_coal.rdd.getNumPartitions()}")
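Repartitioning by a column is hash partitioning: every row with a given key goes to the same partition. For readability, this plain-Python sketch uses `key % num_partitions` as the hash. Spark actually uses a Murmur3 hash, so real partition assignments will differ.

```python
# Plain-Python model of hash partitioning by a column (not Spark code).
# Assumption: key % n stands in for Spark's Murmur3-based hash.
from collections import defaultdict

orders = [(101, 1), (102, 1), (103, 2), (104, 3), (105, 5)]  # (order_id, customer_id)
num_partitions = 4

partitions = defaultdict(list)
for order_id, customer_id in orders:
    partitions[customer_id % num_partitions].append(order_id)

print(dict(partitions))
```

Keys 1 and 5 share a partition and partition 0 stays empty. Hash partitioning keeps each key's rows together but does not guarantee balanced partitions.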

### Filter Pushdown

In [None]:
# Good: Filter before join (reduces data to join)
high_value_orders = orders.filter(col("amount") > 100)
result_optimized = customers.join(high_value_orders, "customer_id")

# Filter after join: Catalyst usually pushes this filter below the join anyway
result_unoptimized = customers.join(orders, "customer_id").filter(col("amount") > 100)

print("Filter before join (optimized):")
display(result_optimized)

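# Both produce the same rows. Catalyst's predicate pushdown typically moves the
# filter below the join in the second query too; compare their .explain() output.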
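The sample rows are enough to check the equivalence claim. This plain-Python sketch (not Spark code) confirms that filtering before or after an inner join returns the same rows, and that fewer order rows enter the join when the filter runs first.

```python
# Plain-Python model (not Spark code): filter-before vs filter-after an inner join.
customers = {1: "Alice", 2: "Bob", 3: "Charlie", 4: "Diana"}
orders = [(101, 1, 150.0), (102, 1, 200.0), (103, 2, 75.0), (104, 3, 300.0), (105, 5, 120.0)]

def inner_join(order_rows):
    """Join orders to customer names; orders without a customer are dropped."""
    return [(customers[cid], oid, amt) for oid, cid, amt in order_rows if cid in customers]

filtered_first = [o for o in orders if o[2] > 100]
before = inner_join(filtered_first)
after = [r for r in inner_join(orders) if r[2] > 100]

print(before == after)
print(len(filtered_first), "vs", len(orders), "order rows entering the join")
```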

### Column Pruning

In [None]:
# Good: Select only needed columns early
customers_subset = customers.select("customer_id", "name")
result = orders.join(customers_subset, "customer_id")

# Selecting columns after the join usually yields the same plan, because Catalyst
# prunes unused columns; selecting early keeps the intent explicit.

## 7. Handling Data Skew

In [None]:
# Data skew: when one key has many more values than others

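# Technique 1: Salt the key
# Append a random salt (0-9) to the skewed key so its rows spread over up to
# 10 join keys. The other table must be replicated once per salt value and
# joined on the same salted key; otherwise rows fail to match.

skewed_df = orders.withColumn(
    "salted_key",
    concat(col("customer_id").cast("string"), lit("_"), (rand() * 10).cast("int").cast("string"))
)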

# Technique 2: Broadcast the smaller table
# Already covered above

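# Technique 3: Use AQE (Adaptive Query Execution), enabled by default since Spark 3.2.
# With skewJoin enabled, AQE splits oversized partitions of a sort-merge join at runtime.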
# spark.conf.set("spark.sql.adaptive.enabled", "true")
# spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

print("Salted keys for skew handling:")
display(skewed_df.select("customer_id", "salted_key"))
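Salting only the large side is half of the technique. The small side must be replicated once per salt value so that every salted key still finds its match. This plain-Python model (not Spark code) assumes a hypothetical hot customer 1 with 1,000 orders and an arbitrary `SALT_BUCKETS = 4`. It checks that the salted join returns exactly the rows of the plain join.

```python
# Plain-Python model of a salted join (not Spark code). Assumptions: customer 1
# is a hypothetical "hot" key, and SALT_BUCKETS = 4 is an arbitrary choice.
import random
from collections import Counter

SALT_BUCKETS = 4
rng = random.Random(42)  # seeded so the sketch is reproducible

orders = [(oid, 1) for oid in range(1000)] + [(1000, 2), (1001, 3)]  # (order_id, customer_id)
customers = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]

# Large, skewed side: attach a random salt to each row's key
salted_orders = [(oid, (cid, rng.randrange(SALT_BUCKETS))) for oid, cid in orders]

# Small side: replicate each row once per salt value so every salted key can match
salted_customers = {
    (cid, salt): name for cid, name in customers for salt in range(SALT_BUCKETS)
}

# Join on the composite (key, salt)
salted_join = sorted((oid, salted_customers[key]) for oid, key in salted_orders)

# Reference: ordinary join without salting
names = dict(customers)
plain_join = sorted((oid, names[cid]) for oid, cid in orders)

# The hot key's rows are now split across SALT_BUCKETS join keys instead of one
hot_key_spread = Counter(salt for _, (cid, salt) in salted_orders if cid == 1)

print(salted_join == plain_join)
print(sorted(hot_key_spread.items()))
```

In PySpark, the replication step is usually done by exploding an array of salt values on the small table (for example with `explode` over `array` of literals), and the join then uses both the original key and the salt column.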

## 8. Explain Plans

In [None]:
# View physical plan
join_query = customers.join(orders, "customer_id")

print("Physical Plan:")
join_query.explain()

print("\n" + "="*50)
print("Extended Explanation:")
join_query.explain(extended=True)

## Practice Exercises

### Exercise 1: Customer Purchase Analysis
Find customers who placed orders in both Q1 and Q2 (use semi join or intersect).

In [None]:
# Your solution here
# TODO: Create Q1 and Q2 order datasets, find customers in both

### Exercise 2: Optimize a Join
Given a large orders table and small products table, write an optimized join.

In [None]:
# Your solution here
# TODO: Use broadcast join for optimal performance

## Summary

In this notebook, you learned:

✅ All join types (inner, left, right, full, cross, semi, anti)
✅ Multiple join conditions
✅ Broadcast joins for performance
✅ Self joins
✅ Union and unionByName operations
✅ Performance optimization (caching, repartitioning, filter pushdown)
✅ Handling data skew
✅ Reading explain plans

## Next Steps

1. Practice with larger datasets
2. Monitor query plans and optimize
3. Learn about Adaptive Query Execution (AQE)
4. Study join strategies in depth

## Additional Resources

- [Spark SQL Join Types (Spark By Examples)](https://sparkbyexamples.com/spark/spark-sql-join-types/)
- [Spark SQL Performance Tuning Guide](https://spark.apache.org/docs/latest/sql-performance-tuning.html)
- [PySpark Tutorial (Spark By Examples)](https://sparkbyexamples.com/pyspark-tutorial/)