In [10]:
# Step 1: Set up PySpark Session
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder \
    .appName("DataSkewHandling") \
    .getOrCreate()

25/02/14 01:22:44 WARN Utils: Your hostname, codespaces-976928 resolves to a loopback address: 127.0.0.1; using 10.0.2.76 instead (on interface eth0)
25/02/14 01:22:44 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/14 01:22:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [11]:
# Step 2: Create a Sample DataFrame with Skew
from pyspark.sql import functions as F
# Create a DataFrame with skewed data
data = [(1, "A")] * 1000 + [(2, "B")] * 100 + [(3, "C")] * 10
df = spark.createDataFrame(data, ["id", "category"])
# Show the DataFrame
print("Sample DataFrame:")
df.show(5)


Sample DataFrame:


                                                                                

+---+--------+
| id|category|
+---+--------+
|  1|       A|
|  1|       A|
|  1|       A|
|  1|       A|
|  1|       A|
+---+--------+
only showing top 5 rows



In [12]:
# Step 3: Diagnose Data Skew
# Check the number of rows per partition
print("\nNumber of rows per partition:")
df.groupBy(F.spark_partition_id()).count().show()
# Inspect data distribution in partitions
print("\nData in partitions (first 2 rows per partition):")
partitions = df.rdd.glom().collect()
for i, partition in enumerate(partitions):
    print(f"Partition {i}: {partition[:2]}")



Number of rows per partition:
+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                   0|  555|
|                   1|  555|
+--------------------+-----+


Data in partitions (first 2 rows per partition):
Partition 0: [Row(id=1, category='A'), Row(id=1, category='A')]
Partition 1: [Row(id=1, category='A'), Row(id=1, category='A')]


In [16]:
# Check distribution of 'id' across partitions
df.withColumn("partition_id", F.spark_partition_id()) \
  .groupBy("partition_id", "id") \
  .count() \
  .orderBy("partition_id", "id") \
  .show()

+------------+---+-----+
|partition_id| id|count|
+------------+---+-----+
|           0|  1|  555|
|           1|  1|  445|
|           1|  2|  100|
|           1|  3|   10|
+------------+---+-----+



# Repartitioning by column

We can use df.repartition function with number of partitionsand the columns on which we want to partition. We can ensure that data is evenly distributed across workers reducing skew and improving performance.

`df = df. repartition (<n_partitions>, '‹col_1>', '‹col_2>',・・・）`

In [14]:
# Step 5: Handle Data Skew - Repartition by Column
# Repartition the DataFrame by the skewed column
print("\nRepartitioning by 'id' column...")
df_repartitioned = df.repartition("id")
# Check the new distribution
print("\nNumber of rows per partition after repartitioning:")
df_repartitioned.groupBy(F.spark_partition_id()).count().show()



Repartitioning by 'id' column...

Number of rows per partition after repartitioning:


25/02/14 01:23:00 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors


+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                   0| 1110|
+--------------------+-----+



# Repartition using salt
```
import pyspark.sql. functions as F
# Add a 'salt' column with a random value for each row
df = df.withColumn ('salt', F. rand ())
# Repartition the DataFrame into 8 partitions based on the
'salt' column
df = df. repartition (8, 'salt')


Salting involves adding a random value to your data to distribute it more evenly.`F.rand` function assigns a random float between zero and oneto each row in the data frame.This introduces randomness to distribute skewed keysmore evenly across partitions.

In [15]:
# Step 6: Handle Data Skew - Salting
# Add a salt column to evenly distribute data
print("\nAdding a salt column for even distribution...")
df_salted = df.withColumn("salt", F.rand())
# Repartition by the salt column
df_salted = df_salted.repartition(8, "salt")
# Check the new distribution
print("\nNumber of rows per partition after salting:")
df_salted.groupBy(F.spark_partition_id()).count().show()



Adding a salt column for even distribution...

Number of rows per partition after salting:
+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                   0|  123|
|                   1|  127|
|                   2|  129|
|                   3|  146|
|                   4|  133|
|                   5|  133|
|                   6|  167|
|                   7|  152|
+--------------------+-----+

