In [0]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import random

spark = SparkSession.builder.appName("Salting Example").getOrCreate()

In [0]:
# Sample data
data = [("A", 1), ("A", 2), ("A", 3), ("B", 1), ("B", 2), ("C", 1)]
df = spark.createDataFrame(data, ["key", "value"])
df.show()

+---+-----+
|key|value|
+---+-----+
|  A|    1|
|  A|    2|
|  A|    3|
|  B|    1|
|  B|    2|
|  C|    1|
+---+-----+



In [0]:
# Introduce salting
salted_df = (
    df.withColumn("salt", (F.rand() * 5).cast("int"))  # Create a salt column with values 0-4
    .withColumn("salted_key", F.concat(F.col("key"), F.lit("_"), F.col("salt")))  # Combine key and salt
)

salted_df.show()

+---+-----+----+----------+
|key|value|salt|salted_key|
+---+-----+----+----------+
|  A|    1|   3|       A_3|
|  A|    2|   4|       A_4|
|  A|    3|   4|       A_4|
|  B|    1|   1|       B_1|
|  B|    2|   3|       B_3|
|  C|    1|   0|       C_0|
+---+-----+----+----------+



In [0]:
# Perform the operation (e.g., group by)
result_df = (
    salted_df.groupBy("salted_key")
    .agg(F.sum("value").alias("sum_value"))
)

result_df.show()

+----------+---------+
|salted_key|sum_value|
+----------+---------+
|       A_3|        1|
|       A_4|        5|
|       B_1|        1|
|       B_3|        2|
|       C_0|        1|
+----------+---------+



In [0]:
# Remove salt to restore original keys (optional)
final_result = (
    result_df.withColumn("key", F.split(F.col("salted_key"), "_")[0])
    .select("key", "sum_value")
)

final_result.show()

+---+---------+
|key|sum_value|
+---+---------+
|  A|        1|
|  A|        5|
|  B|        1|
|  B|        2|
|  C|        1|
+---+---------+

