In [None]:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, concat, lit, rand

# Initialize Spark session
spark = SparkSession.builder.appName("SaltingExample").getOrCreate()

# Salting parameters.
# NUM_SALT_BUCKETS: how many sub-keys each original key is split into.
# SALT_SEED: makes the random salt reproducible, so "Restart & Run All" yields
# the same salted keys and grouped output every time.
NUM_SALT_BUCKETS = 10
SALT_SEED = 42

# Sample data creation: key1 (3 rows) and key3 (4 rows) are over-represented
# relative to key2 (1 row), mimicking skewed "hot" keys.
data = [
    ("key1", "value1"),
    ("key1", "value2"),
    ("key1", "value3"),
    ("key2", "value4"),
    ("key3", "value5"),
    ("key3", "value6"),
    ("key3", "value7"),
    ("key3", "value8"),
]

# Create DataFrame
df = spark.createDataFrame(data, ["key", "value"])

# Show initial data
print("Initial Data:")
df.show()

# Check the number of partitions
print(f"Initial number of partitions: {df.rdd.getNumPartitions()}")

# Salting process
# Append a random bucket suffix "_<0..NUM_SALT_BUCKETS-1>" to each key, so rows
# sharing a hot key hash to different shuffle partitions during the groupBy.
# rand() is in [0, 1), so the int cast yields 0..NUM_SALT_BUCKETS-1. The explicit
# string cast avoids relying on implicit coercion inside concat. The seed makes
# the salt deterministic across runs (for the same input partitioning).
salt = (rand(seed=SALT_SEED) * NUM_SALT_BUCKETS).cast("int").cast("string")
salted_df = df.withColumn("salted_key", concat(col("key"), lit("_"), salt))

# Show salted data
print("Salted Data:")
salted_df.show()

# withColumn is a narrow transformation, so the partition count cannot change
# here. The salt only affects data placement at the groupBy shuffle below.
print(f"Number of partitions after salting: {salted_df.rdd.getNumPartitions()}")

# Stage 1: partial aggregation on the salted key (spreads hot keys across tasks).
grouped_df = salted_df.groupBy("salted_key").count()

# Show the partial (per-salted-key) counts
print("Grouped Data:")
grouped_df.show()

# Stage 2: strip the "_<bucket>" suffix and combine the partial counts into the
# per-key totals that salting was meant to compute. The regex is anchored at the
# end of the string and matches digits only, so a key that itself contains "_"
# loses only the appended salt.
final_counts = (
    grouped_df
    .selectExpr("regexp_replace(salted_key, '_[0-9]+$', '') AS key", "`count`")
    .groupBy("key")
    .agg({"count": "sum"})
    .withColumnRenamed("sum(count)", "count")
)

print("Final Counts per Key:")
final_counts.show()

# Salting must not change the answer: the two-stage result should equal a
# direct (unsalted) groupBy on the original key.
expected_counts = {row["key"]: row["count"] for row in df.groupBy("key").count().collect()}
actual_counts = {row["key"]: row["count"] for row in final_counts.collect()}
assert actual_counts == expected_counts, (
    f"salted aggregation mismatch: {actual_counts} != {expected_counts}"
)



: 

In [3]:
# Stop the Spark session created at the top of the notebook and release its
# resources. No cell after this one can use `spark` or any DataFrame built from it.
spark.stop()