In [0]:
# Session settings for this demo: adaptive query execution is switched off and
# the auto-broadcast threshold is set to -1 (disabled).
# NOTE(review): presumably so the joins below stay shuffle joins and the skew
# is not optimized away by Spark — confirm the intent.
for conf_key, conf_value in {
    "spark.sql.adaptive.enabled": "false",
    "spark.sql.autoBroadcastJoinThreshold": -1,
}.items():
    spark.conf.set(conf_key, conf_value)

In [0]:
from pyspark.sql import functions as F

In [0]:
# Data skew demo: one join key (id = 1) holds most of the left-side rows.

In [0]:
# Left side of the join, deliberately skewed:
#   - 90,000 rows that all carry id == 1 (the hot key)
#   - 1,000 rows with one distinct id each, ids 2..1001
hot_key_rows = [(1, f"left_{n}") for n in range(90000)]
regular_rows = [(key, f"left_{key}") for key in range(2, 1002)]
left_data = hot_key_rows + regular_rows
left_df = spark.createDataFrame(left_data, schema=["id", "left_value"])
display(left_df)

In [0]:
# Right side of the join: exactly one row per id, ids 1..1001.
right_data = [(key, f"right_{key}") for key in range(1, 1002)]
right_df = spark.createDataFrame(right_data, schema=["id", "right_value"])
display(right_df)

In [0]:
def count_records_per_partition(df):
    """Return the number of records held by each partition of ``df``.

    Args:
        df: DataFrame whose underlying RDD partitions are inspected.

    Returns:
        A DataFrame with columns ``partitionId`` and ``record_count``, one row
        per partition (empty partitions are reported with a count of 0),
        ordered by ``partitionId``.
    """
    def _partition_size(partition_index, rows):
        # Walk the partition iterator once and emit a single (index, size) pair.
        size = 0
        for _ in rows:
            size += 1
        return [(partition_index, size)]

    sizes_rdd = df.rdd.mapPartitionsWithIndex(_partition_size)
    sizes_df = sizes_rdd.toDF(["partitionId", "record_count"])
    return sizes_df.orderBy("partitionId")

In [0]:
# Records per partition of left_df as created, before any explicit repartitioning.
display(count_records_per_partition(left_df))

In [0]:
# Redistribute the left side into 4 partitions keyed on the join column, then
# show how many records each partition received.
left_df = left_df.repartition(4, "id")
left_counts_by_id = count_records_per_partition(left_df)
display(left_counts_by_id)

In [0]:
# Number of partitions backing left_df (shown as the cell's output).
left_df.rdd.getNumPartitions()

In [0]:
# Peek at the contents of each partition. glom() turns every partition into a
# single list; only the first 5 rows of each list are kept so that the driver
# and the cell output are not flooded with all of left_df's rows.
left_df.rdd.glom().map(lambda partition_rows: partition_rows[:5]).collect()

In [0]:
# Number of partitions backing right_df (shown as the cell's output).
right_df.rdd.getNumPartitions()

In [0]:
# Records per partition of right_df (one row per id, so no hot key here).
display(count_records_per_partition(right_df))

In [0]:
# Baseline: left join on the raw, skewed key without any mitigation.
skewed_join = left_df.join(right_df, on="id", how="left")

skewed_join.show()

In [0]:
# Salting: spread the hot key (id == 1) over `salt_range` synthetic sub-keys so
# its rows no longer hash to a single partition.
salt_range = 10

# Left side: rows with id == 1 get a random integer salt in [0, salt_range);
# every other row keeps salt 0. The join key is "<id>_<salt>".
# rand() is seeded so re-running the notebook does not silently reshuffle the
# salts. NOTE(review): Spark documents rand() as non-deterministic in general,
# so exact per-row salts can still vary if row order inside a partition changes.
left_df_salted = (
    left_df
    .withColumn(
        "salt",
        F.when(F.col("id") == 1, F.rand(seed=42) * salt_range)
        .otherwise(F.lit(0))
        .cast("int")
    )
    .withColumn("join_key", F.concat_ws("_", F.col("id"), F.col("salt")))
)

# Right side: replicate every row once per possible salt value (0..salt_range-1)
# so that whichever salt a left row received, a matching join_key exists.
right_df_explode = (
    right_df
    .withColumn("salt", F.explode(F.array([F.lit(i) for i in range(salt_range)])))
    .withColumn("join_key", F.concat_ws("_", F.col("id"), F.col("salt")))
)

In [0]:
# Records per partition of the salted left side; the salt/join_key columns have
# only been added at this point, no repartition on join_key has happened yet.
display(count_records_per_partition(left_df_salted))

In [0]:
# Redistribute the salted left side into 4 partitions keyed on join_key and
# show the resulting per-partition record counts.
left_df_salted = left_df_salted.repartition(4, "join_key")
left_counts_by_join_key = count_records_per_partition(left_df_salted)
display(left_counts_by_join_key)

In [0]:
# optional: repartition the exploded right side on join_key as well and inspect
# the distribution of the repartitioned frame (not the pre-repartition one).
right_df_salted = right_df_explode.repartition(4, "join_key")
count_records_per_partition(right_df_salted).display()

In [0]:
# Salted join: match on the composite join_key instead of the raw id.
result_join = (
    left_df_salted
    .join(
        # Bring over only the key and the payload from the right side. Both
        # inputs carry `id` and `salt`; joining the full right frame would
        # leave two columns with each of those names in the result, which
        # makes later references to them ambiguous.
        right_df_explode.select("join_key", "right_value"),
        on=["join_key"],
        how="left"
    )
)

result_join.show()