In [1]:
# Spark Session
from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .appName("Optimizing Skewness and Spillage")
    # .master("spark://192.168.1.12:7077")
    .master("local[*]")  # Use local mode for development
    .config("spark.cores.max", 8)
    .config("spark.executor.cores", 4)
    .config("spark.executor.memory", "512M")
    .getOrCreate()
)

spark


In [2]:
# Disable AQE and Broadcast join

spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", False)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

In [3]:
# Read EMP CSV data

_schema = "first_name string, last_name string, job_title string, dob string, email string, phone string, salary double, department_id int"

emp = spark.read.format("csv").schema(_schema).option("header", True).load(r"/content/employee_records.csv")

In [4]:
# Read DEPT CSV data

_dept_schema = "department_id int, department_name string, description string, city string, state string, country string"

dept = spark.read.format("csv").schema(_dept_schema).option("header", True).load(r"/content/department_data.csv")

In [5]:
# Join Datasets

df_joined = emp.join(dept, on=emp.department_id==dept.department_id, how="left_outer")


In [6]:
df_joined.write.format("noop").mode("overwrite").save()

In [7]:
#Explain Plan

df_joined.explain()

== Physical Plan ==
*(4) SortMergeJoin [department_id#7], [department_id#16], LeftOuter
:- *(1) Sort [department_id#7 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(department_id#7, 200), ENSURE_REQUIREMENTS, [plan_id=70]
:     +- FileScan csv [first_name#0,last_name#1,job_title#2,dob#3,email#4,phone#5,salary#6,department_id#7] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/content/employee_records.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<first_name:string,last_name:string,job_title:string,dob:string,email:string,phone:string,s...
+- *(3) Sort [department_id#16 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(department_id#16, 200), ENSURE_REQUIREMENTS, [plan_id=82]
      +- *(2) Filter isnotnull(department_id#16)
         +- FileScan csv [department_id#16,department_name#17,description#18,city#19,state#20,country#21] Batched: false, DataFilters: [isnotnull(department_id#16)], Format: CSV, Location

In [8]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count, lit

part_df = df_joined.withColumn("partition_num", spark_partition_id()).groupBy("partition_num").agg(count(lit(1)).alias("count"))

part_df.show()

+-------------+------+
|partition_num| count|
+-------------+------+
|           42|473632|
+-------------+------+



In [9]:
# Verify Employee data based on department_id
from pyspark.sql.functions import count, lit, desc, col

emp.groupBy("department_id").agg(count(lit(1))).show()

+-------------+--------+
|department_id|count(1)|
+-------------+--------+
|         NULL|  473632|
+-------------+--------+



In [21]:
# Set shuffle partitions to a lesser number - 16

spark.conf.set("spark.sql.shuffle.partitions", 32)

In [22]:
# Let prepare the salt
import random
from pyspark.sql.functions import udf

# UDF to return a random number every time and add to Employee as salt
@udf
def salt_udf():
    return random.randint(0, 32)

# Salt Data Frame to add to department
salt_df = spark.range(0, 32)
salt_df.show()

+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
|  5|
|  6|
|  7|
|  8|
|  9|
| 10|
| 11|
| 12|
| 13|
| 14|
| 15|
| 16|
| 17|
| 18|
| 19|
+---+
only showing top 20 rows



In [23]:
# Salted Employee
from pyspark.sql.functions import lit, concat

salted_emp = emp.withColumn("salted_dept_id", concat("department_id", lit("_"), salt_udf()))

salted_emp.show()

+----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+--------------+
|first_name| last_name|           job_title|       dob|               email|               phone|  salary|department_id|salted_dept_id|
+----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+--------------+
|   Richard|  Morrison|Public relations ...|1973-05-05|melissagarcia@exa...|       (699)525-4827|512653.0|         NULL|          NULL|
|     Bobby|  Mccarthy|   Barrister's clerk|1974-04-25|   llara@example.net|  (750)846-1602x7458|999836.0|         NULL|          NULL|
|    Dennis|    Norman|Land/geomatics su...|1990-06-24| jturner@example.net|    873.820.0518x825|131900.0|         NULL|          NULL|
|      John|    Monroe|        Retail buyer|1968-06-16|  erik33@example.net|    820-813-0557x624|485506.0|         NULL|          NULL|
|  Michelle|   Elliott|      Air cabin crew|1975

In [24]:
# Salted Department

salted_dept = dept.join(salt_df, how="cross").withColumn("salted_dept_id", concat("department_id", lit("_"), "id"))

salted_dept.where("department_id = 9").show()

+-------------+--------------------+--------------------+-----------+-----+-------+---+--------------+
|department_id|     department_name|         description|       city|state|country| id|salted_dept_id|
+-------------+--------------------+--------------------+-----------+-----+-------+---+--------------+
|            9|Mcmahon, Terrell ...|De-engineered hig...|Marychester|   MN|  Italy|  0|           9_0|
|            9|Mcmahon, Terrell ...|De-engineered hig...|Marychester|   MN|  Italy|  1|           9_1|
|            9|Mcmahon, Terrell ...|De-engineered hig...|Marychester|   MN|  Italy|  2|           9_2|
|            9|Mcmahon, Terrell ...|De-engineered hig...|Marychester|   MN|  Italy|  3|           9_3|
|            9|Mcmahon, Terrell ...|De-engineered hig...|Marychester|   MN|  Italy|  4|           9_4|
|            9|Mcmahon, Terrell ...|De-engineered hig...|Marychester|   MN|  Italy|  5|           9_5|
|            9|Mcmahon, Terrell ...|De-engineered hig...|Marychester|   M

In [25]:
# Lets make the salted join now

salted_joined_df = salted_emp.join(salted_dept, on=salted_emp.salted_dept_id==salted_dept.salted_dept_id, how="left_outer")

In [26]:
salted_joined_df.write.format("noop").mode("overwrite").save()

In [27]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count

part_df = salted_joined_df.withColumn("partition_num", spark_partition_id()).groupBy("partition_num").agg(count(lit(1)).alias("count"))

part_df.show()

+-------------+------+
|partition_num| count|
+-------------+------+
|           10|473632|
+-------------+------+

