In [3]:
from pyspark.sql import SparkSession

In [4]:
spark = SparkSession.builder.appName("pyspark optimization techniques").getOrCreate()

In [28]:
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled","true")  #enabling dpp

Dynamic partiton pruning

In [10]:
from pyspark.sql.functions import col
sales = spark.range(0, 100000).withColumn("region", (col("id") % 10).cast("string"))
sales.write.partitionBy("region").mode("overwrite").parquet("/tmp/sales_data")

# Read partitioned data
sales_df = spark.read.parquet("/tmp/sales_data")

# Create filter table
targets = spark.createDataFrame([("1",), ("2",)], ["region"])

In [11]:
sales_df.alias("s").join(targets.alias("t"), "region").select("id").show()

+-----+
|   id|
+-----+
|49991|
|49981|
|49971|
|49961|
|49951|
|49941|
|49931|
|49921|
|49911|
|49901|
|49891|
|49881|
|49871|
|49861|
|49851|
|49841|
|49831|
|49821|
|49811|
|49801|
+-----+
only showing top 20 rows



#Predicate pushdown

filter pushdown and predicate pushdown are same terms in pyspark, same effects

In [20]:
from pyspark.sql.functions import col
import time

df = spark.range(0,10_000_000).withColumn("amount", (col("id")*2))
df.write.mode("overwrite").parquet("/tmp.parquet_data")

start = time.time()

df_p = spark.read.parquet("/tmp.parquet_data")
df_p.filter(col("amount")>19_999_000).show()

print("time with predicate pushown will be:", round(time.time()-start,2),"sec")


#project pushdown
df = spark.read.parquet("/tmp.parquet_data")
df.select("amount", "id").show()

+-------+--------+
|     id|  amount|
+-------+--------+
|9999501|19999002|
|9999502|19999004|
|9999503|19999006|
|9999504|19999008|
|9999505|19999010|
|9999506|19999012|
|9999507|19999014|
|9999508|19999016|
|9999509|19999018|
|9999510|19999020|
|9999511|19999022|
|9999512|19999024|
|9999513|19999026|
|9999514|19999028|
|9999515|19999030|
|9999516|19999032|
|9999517|19999034|
|9999518|19999036|
|9999519|19999038|
|9999520|19999040|
+-------+--------+
only showing top 20 rows

time with predicate pushown will be: 0.76 sec
+--------+-------+
|  amount|     id|
+--------+-------+
|10000000|5000000|
|10000002|5000001|
|10000004|5000002|
|10000006|5000003|
|10000008|5000004|
|10000010|5000005|
|10000012|5000006|
|10000014|5000007|
|10000016|5000008|
|10000018|5000009|
|10000020|5000010|
|10000022|5000011|
|10000024|5000012|
|10000026|5000013|
|10000028|5000014|
|10000030|5000015|
|10000032|5000016|
|10000034|5000017|
|10000036|5000018|
|10000038|5000019|
+--------+-------+
only showing top

#5M Customers 10k Products 10M transactions    Data generation time

In [21]:
from pyspark.sql.functions import col, expr, rand
from pyspark.sql.types import IntegerType, FloatType, StringType
import time

# 5M Customers
customers = spark.range(1, 5_000_001).withColumnRenamed("id", "customer_id")

# 10k Products with categories and prices
products = spark.range(1, 10_001).withColumnRenamed("id", "product_id") \
    .withColumn("category", expr("concat('Category_', product_id % 20)")) \
    .withColumn("price", (rand() * 100 + 1).cast("float"))

# 10M Transactions
transactions = spark.range(1, 10_000_001).withColumnRenamed("id", "txn_id") \
    .withColumn("customer_id", expr("cast(rand() * 5000000 + 1 as long)")) \
    .withColumn("product_id", expr("cast(rand() * 10000 + 1 as long)")) \
    .withColumn("txn_amount", (rand() * 200 + 1).cast("float")) \
    .withColumn("txn_date", expr("date_add('2020-01-01', cast(rand() * 1000 as int))"))

In [22]:
start = time.time()

customers.write.mode("overwrite").parquet("/content/customers_parquet")
products.write.mode("overwrite").parquet("/content/products_parquet")
transactions.write.mode("overwrite").parquet("/content/transactions_parquet")

print("✅ Data generation complete in", round(time.time() - start, 2), "seconds")

✅ Data generation complete in 13.35 seconds


In [24]:
df = spark.read.parquet("/content/transactions_parquet")
df.show(15)

+-------+-----------+----------+----------+----------+
| txn_id|customer_id|product_id|txn_amount|  txn_date|
+-------+-----------+----------+----------+----------+
|5000001|    2360511|      7498| 159.32028|2021-06-30|
|5000002|    2677145|      7864| 3.5967855|2020-01-01|
|5000003|    1757886|      9583| 1.6146764|2021-03-31|
|5000004|    3448918|      1379| 170.29713|2020-04-20|
|5000005|     717033|      3445| 170.64154|2020-06-04|
|5000006|    4810544|      4824| 56.250557|2022-08-09|
|5000007|    4587612|      1799| 177.81328|2021-07-05|
|5000008|    1912982|      7795| 100.88956|2022-04-26|
|5000009|    1819159|       114| 153.16435|2022-04-01|
|5000010|    3689421|      1364| 165.82265|2020-04-21|
|5000011|    1498691|      6662|105.492714|2021-07-29|
|5000012|    3087565|      4258|  186.7567|2021-06-15|
|5000013|     729633|      8765|  51.58265|2020-10-26|
|5000014|    4508528|      5995| 181.63846|2020-03-14|
|5000015|    2696518|      3199| 93.597984|2022-09-15|
+-------+-

In [26]:
transactions = spark.read.parquet("/content/transactions_parquet")

sorted_txns = transactions.orderBy("txn_amount")

sorted_txns.show(10)

+-------+-----------+----------+----------+----------+
| txn_id|customer_id|product_id|txn_amount|  txn_date|
+-------+-----------+----------+----------+----------+
|1042120|    1097813|      5960| 1.0000049|2020-03-24|
|9761303|    3985205|      4798| 1.0000216|2020-08-12|
|5996180|    4921044|      9159|  1.000022|2021-04-01|
|2226240|    3895459|      3447| 1.0000298|2020-02-24|
|2780339|    1305272|      5288| 1.0000516|2020-12-18|
|8139010|    1695054|      9159| 1.0000608|2021-08-14|
|1091272|    4268923|      5756| 1.0000753|2022-05-29|
|1269259|      81727|       715| 1.0000994|2022-06-20|
|7894708|     387144|      9071| 1.0001372|2021-12-21|
|3875477|    2349743|      2759| 1.0001428|2021-05-19|
+-------+-----------+----------+----------+----------+
only showing top 10 rows



WSCG Whole stage code generation

In [32]:
spark.conf.get("spark.sql.codegen.wholeStage")

'true'

In [27]:
#wholestage code generation checking

from pyspark.sql.functions import col

df = spark.range(1, 1_000_000).withColumn("value", col("id") * 2)

# Run some transformations
filtered = df.filter(col("value") > 1000).select("value")


filtered.explain()

== Physical Plan ==
*(1) Project [(id#457L * 2) AS value#459L]
+- *(1) Filter ((id#457L * 2) > 1000)
   +- *(1) Range (1, 1000000, step=1, splits=2)




how to enable AQE ADaptive query execution

In [30]:
spark = SparkSession.builder \
    .appName("AQE_Off_Example") \
    .config("spark.sql.adaptive.enabled", "false") \
    .getOrCreate()


how to enable kryo serilization

In [29]:
spark = SparkSession.builder \
    .appName("KryoExample") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .getOrCreate()

Avoiding UDFs (or using pandas UDFs where necessary)

In [33]:
from pyspark.sql.functions import pandas_udf
import pandas as pd

In [35]:
@pandas_udf("int")
def square_pandas(s: pd.Series) -> pd.Series:
    return s * s

df.withColumn("squared", square_pandas(col("value"))).show()

+---+-----+-------+
| id|value|squared|
+---+-----+-------+
|  1|    2|      4|
|  2|    4|     16|
|  3|    6|     36|
|  4|    8|     64|
|  5|   10|    100|
|  6|   12|    144|
|  7|   14|    196|
|  8|   16|    256|
|  9|   18|    324|
| 10|   20|    400|
| 11|   22|    484|
| 12|   24|    576|
| 13|   26|    676|
| 14|   28|    784|
| 15|   30|    900|
| 16|   32|   1024|
| 17|   34|   1156|
| 18|   36|   1296|
| 19|   38|   1444|
| 20|   40|   1600|
+---+-----+-------+
only showing top 20 rows



Tuning spark.sql.shuffle.partitions

In [36]:
#first method
spark = SparkSession.builder \
    .appName("ShufflePartitionTuning") \
    .config("spark.sql.shuffle.partitions", "100") \
    .getOrCreate()

In [37]:
#second method will be to set it dynamically at run time
spark.conf.set("spark.sql.shuffle.partitions", 50)

Avoiding collect() on large datasets

In [42]:
# BAD: Will collect all 10 million rows into driver's memory
df_large = spark.range(1, 10_000_000)

df_large.show(5)                # Just displays first 5 rows/best technique

+---+
| id|
+---+
|  1|
|  2|
|  3|
|  4|
|  5|
+---+
only showing top 5 rows



Reuse of DataFrames or intermediate results

In [44]:
#cache
df_filtered = df.filter(col("value") > 1000).cache()

In [46]:
#persist
df_filtered = df.filter(col("value") > 1000).persist()

Writing Data in Optimal Partition Size (~128MB) (can be simulated)

In [None]:
#we will use repartiton in order to dicide into small files instead of suing large file like ion GBs or eg 300Mb