In [1]:
# Import modules
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast, col, rand, skewness,lit

In [3]:
spark = (SparkSession.builder
         .appName("optimize-join-strategies")
         .master("spark://spark-master:7077")
         .config("spark.executor.memory", "512m")
         .getOrCreate())

spark.sparkContext.setLogLevel("ERROR")

In [4]:
# Create some sample data frames
# A large data frame with 10 million rows and two columns: id and value
large_df = spark.range(0, 1000000).withColumn("value", rand(seed=42))

# A small data frame with 10000 rows and two columns: id and name
small_df = spark.range(0, 10000).withColumn("name", col("id").cast("string"))

# A skewed data frame with 10 million rows and two columns: id and value
# The id column has a Zipf distribution with a skewness of 4.7
skewed_df = spark.range(0, 1000000).withColumn("value", rand(seed=42)).withColumn("id", col("id") ** 4)

In [5]:
# Define a function to measure the execution time of a query
import time

def measure_time(query):
    start = time.time()
    query.collect() # Force the query execution by calling an action
    end = time.time()
    print(f"Execution time: {end - start} seconds")

## Choosing the right join type

In [10]:
# Join large_df and small_df using an inner join on id column
measure_time(large_df.join(small_df, "id"))

# Join large_df and small_df using a left outer join on id column
measure_time(large_df.join(small_df, "id", "left"))

# Join large_df and small_df using a right outer join on id column
measure_time(large_df.join(small_df, "id", "right"))

# Join large_df and small_df using a full outer join on id column
measure_time(large_df.join(small_df, "id", "full"))

# Join large_df and small_df using a left semi join on id column
measure_time(large_df.join(small_df, "id", "left_semi"))

# Join large_df and small_df using a left anti join on id column
measure_time(large_df.join(small_df, "id", "left_anti"))

Execution time: 1.0291528701782227 seconds


                                                                                

Execution time: 25.628353357315063 seconds


                                                                                

Execution time: 5.767467021942139 seconds


                                                                                

Execution time: 18.269603490829468 seconds
Execution time: 0.5185227394104004 seconds
Execution time: 6.6788036823272705 seconds


## Broadcasting small tables

In [16]:
# Join large_df and small_df using an inner join with broadcast hash join hint
spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Join large_df and small_df using an inner join without broadcasting
measure_time(large_df.join(small_df, "id"))

# Join large_df and small_df using an inner join with broadcasting
measure_time(large_df.join(broadcast(small_df), "id"))

                                                                                

Execution time: 3.1970551013946533 seconds
Execution time: 0.20557928085327148 seconds


## Using Join Hints

In [17]:
# Join large_df and small_df using an inner join with broadcast hash join hint
inner_join_broadcast_hint = large_df.hint("broadcast").join(small_df, "id")
measure_time(inner_join_broadcast_hint)

# Join large_df and small_df using an inner join with shuffle hash join hint
inner_join_shuffle_hash_hint = large_df.hint("shuffle_hash").join(small_df, "id")
measure_time(inner_join_shuffle_hash_hint)

# Join large_df and small_df using an inner join with shuffle replicate nested loop join hint
inner_join_shuffle_replicate_nl_hint = large_df.hint("shuffle_replicate_nl").join(small_df, "id")
measure_time(inner_join_shuffle_replicate_nl_hint)

# Join large_df and small_df using an inner join with sort merge join hint
inner_join_merge_hint = large_df.hint("merge").join(small_df, "id")
measure_time(inner_join_merge_hint)

Execution time: 1.8980967998504639 seconds


                                                                                

Execution time: 2.253967046737671 seconds


                                                                                

Execution time: 761.0224421024323 seconds




Execution time: 2.357747793197632 seconds


                                                                                

## Enable Adaptive Query Execution

In [19]:
# Join large_df and skewed_df using an inner join without AQE
spark.conf.set("spark.sql.adaptive.enabled", "false")
inner_join_no_aqe = large_df.join(skewed_df, "id")
measure_time(inner_join_no_aqe)

# Join large_df and skewed_df using an inner join with AQE
spark.conf.set("spark.sql.adaptive.enabled", "true")
inner_join_aqe = large_df.join(skewed_df, "id")
measure_time(inner_join_aqe)

                                                                                

Execution time: 8.188302278518677 seconds




Execution time: 2.7499380111694336 seconds


                                                                                

## Handling skewed data

### Salting

In [20]:
# Join large_df and skewed_df using an inner join with salting
# Add a salt column to the skewed_df with 10 random values
skewed_df_with_salt = skewed_df.withColumn("salt", (rand(seed=42) * 10).cast("int"))

# Join large_df and skewed_df_with_salt on id and salt columns
salted_join = large_df.withColumn("salt", lit(0)).join(skewed_df_with_salt, ["id", "salt"])

# Remove the salt column and self-join on id column
salted_join_no_salt = salted_join.drop("salt").join(skewed_df.select("id"), "id")
measure_time(salted_join_no_salt)



Execution time: 6.502509117126465 seconds


                                                                                

### Repartitioning

In [21]:
# Join large_df and skewed_df using an inner join with repartitioning
# Repartition the skewed_df into 1000 partitions
skewed_df_repartitioned = skewed_df.repartition(1000, "id")

# Join large_df and skewed_df_repartitioned on id column
repartitioned_join = large_df.join(skewed_df_repartitioned, "id")
measure_time(repartitioned_join)



Execution time: 10.696999549865723 seconds


                                                                                

In [25]:
spark.stop()