In [1]:
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("Project root:", project_root)


Project root: /Users/ruchita/data_engineering_projects/data_engineering


In [2]:
from src.utils import create_spark_session, get_logger, get_path
from pyspark.sql.functions import col
from pyspark.sql import DataFrame

spark = create_spark_session(
    os.path.join(project_root, "configs", "spark_config.yaml")
)

logger = get_logger("performance-optimizations")
logger.info("Spark session started")


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/02/05 02:38:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/02/05 02:38:10 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
26/02/05 02:38:10 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
26/02/05 02:38:10 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.
26/02/05 02:38:10 WARN Utils: Service 'SparkUI' could not bind on port 4043. Attempting port 4044.
2026-02-05 02:38:10,355 - INFO - performance-optimizations - Spark session started


In [3]:
silver_enriched_path = get_path(
    os.path.join(project_root, "data", "processed"),
    "silver",
    "transactions_enriched"
)

df = spark.read.parquet(silver_enriched_path)

logger.info(f"Total records: {df.count()}")
df.show(5)


2026-02-05 02:38:21,789 - INFO - performance-optimizations - Total records: 20000
[Stage 4:>                                                          (0 + 1) / 1]

+----------+-----------+--------------+--------+----------------+--------+-------------------+---------------------+-------+--------------------+----------+----------------+---------+------------+---------+------------+----------+-----------------+
|account_id|customer_id|transaction_id|  amount|transaction_type|merchant|transaction_country|transaction_timestamp| status|       customer_name|       dob|customer_country|  segment|account_type|  balance|created_date|  txn_date|is_high_value_txn|
+----------+-----------+--------------+--------+----------------+--------+-------------------+---------------------+-------+--------------------+----------+----------------+---------+------------+---------+------------+----------+-----------------+
| ACC001179|  CUST00316|   TXN00000001|20087.89|             ATM|  Amazon|                UAE|  2025-12-19 00:40:38|SUCCESS|          Craig Dunn|1996-05-07|              US|Corporate|     Savings|174771.68|  2024-09-03|2025-12-19|                1|
| AC

                                                                                

In [4]:
print("Execution plan BEFORE optimizations:")
df.groupBy("customer_id").count().explain(True)


Execution plan BEFORE optimizations:
== Parsed Logical Plan ==
'Aggregate ['customer_id], ['customer_id, count(1) AS count#152L]
+- Relation [account_id#0,customer_id#1,transaction_id#2,amount#3,transaction_type#4,merchant#5,transaction_country#6,transaction_timestamp#7,status#8,customer_name#9,dob#10,customer_country#11,segment#12,account_type#13,balance#14,created_date#15,txn_date#16,is_high_value_txn#17] parquet

== Analyzed Logical Plan ==
customer_id: string, count: bigint
Aggregate [customer_id#1], [customer_id#1, count(1) AS count#152L]
+- Relation [account_id#0,customer_id#1,transaction_id#2,amount#3,transaction_type#4,merchant#5,transaction_country#6,transaction_timestamp#7,status#8,customer_name#9,dob#10,customer_country#11,segment#12,account_type#13,balance#14,created_date#15,txn_date#16,is_high_value_txn#17] parquet

== Optimized Logical Plan ==
Aggregate [customer_id#1], [customer_id#1, count(1) AS count#152L]
+- Project [customer_id#1]
   +- Relation [account_id#0,custome

In [5]:
print("Initial partitions:", df.rdd.getNumPartitions())


Initial partitions: 1


In [6]:
df_repartitioned = df.repartition("txn_date")

print("Partitions after repartition:", df_repartitioned.rdd.getNumPartitions())


Partitions after repartition: 2


In [7]:
df_repartitioned.cache()
df_repartitioned.count()  # materialize cache


20000

In [8]:
print("Execution plan AFTER optimizations:")
df_repartitioned.groupBy("customer_id").count().explain(True)


Execution plan AFTER optimizations:
== Parsed Logical Plan ==
'Aggregate ['customer_id], ['customer_id, count(1) AS count#739L]
+- RepartitionByExpression [txn_date#16]
   +- Relation [account_id#0,customer_id#1,transaction_id#2,amount#3,transaction_type#4,merchant#5,transaction_country#6,transaction_timestamp#7,status#8,customer_name#9,dob#10,customer_country#11,segment#12,account_type#13,balance#14,created_date#15,txn_date#16,is_high_value_txn#17] parquet

== Analyzed Logical Plan ==
customer_id: string, count: bigint
Aggregate [customer_id#1], [customer_id#1, count(1) AS count#739L]
+- RepartitionByExpression [txn_date#16]
   +- Relation [account_id#0,customer_id#1,transaction_id#2,amount#3,transaction_type#4,merchant#5,transaction_country#6,transaction_timestamp#7,status#8,customer_name#9,dob#10,customer_country#11,segment#12,account_type#13,balance#14,created_date#15,txn_date#16,is_high_value_txn#17] parquet

== Optimized Logical Plan ==
Aggregate [customer_id#1], [customer_id#1, 

In [9]:
customers_path = os.path.join(project_root, "data", "raw", "customers.csv")
customers_df = spark.read.csv(customers_path, header=True, inferSchema=True)

optimized_join_df = df_repartitioned.join(
    customers_df.hint("broadcast"),
    on="customer_id",
    how="left"
)

optimized_join_df.explain(True)


== Parsed Logical Plan ==
'Join UsingJoin(LeftOuter, [customer_id])
:- RepartitionByExpression [txn_date#16]
:  +- Relation [account_id#0,customer_id#1,transaction_id#2,amount#3,transaction_type#4,merchant#5,transaction_country#6,transaction_timestamp#7,status#8,customer_name#9,dob#10,customer_country#11,segment#12,account_type#13,balance#14,created_date#15,txn_date#16,is_high_value_txn#17] parquet
+- ResolvedHint (strategy=broadcast)
   +- Relation [customer_id#1031,customer_name#1032,dob#1033,country#1034,segment#1035] csv

== Analyzed Logical Plan ==
customer_id: string, account_id: string, transaction_id: string, amount: double, transaction_type: string, merchant: string, transaction_country: string, transaction_timestamp: timestamp, status: string, customer_name: string, dob: date, customer_country: string, segment: string, account_type: string, balance: double, created_date: date, txn_date: date, is_high_value_txn: int, customer_name: string, dob: date, country: string, segment: 

In [10]:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

df_repartitioned.groupBy("txn_date").sum("amount").explain(True)


== Parsed Logical Plan ==
'Aggregate ['txn_date], ['txn_date, sum(amount#3) AS sum(amount)#1353]
+- RepartitionByExpression [txn_date#16]
   +- Relation [account_id#0,customer_id#1,transaction_id#2,amount#3,transaction_type#4,merchant#5,transaction_country#6,transaction_timestamp#7,status#8,customer_name#9,dob#10,customer_country#11,segment#12,account_type#13,balance#14,created_date#15,txn_date#16,is_high_value_txn#17] parquet

== Analyzed Logical Plan ==
txn_date: date, sum(amount): double
Aggregate [txn_date#16], [txn_date#16, sum(amount#3) AS sum(amount)#1353]
+- RepartitionByExpression [txn_date#16]
   +- Relation [account_id#0,customer_id#1,transaction_id#2,amount#3,transaction_type#4,merchant#5,transaction_country#6,transaction_timestamp#7,status#8,customer_name#9,dob#10,customer_country#11,segment#12,account_type#13,balance#14,created_date#15,txn_date#16,is_high_value_txn#17] parquet

== Optimized Logical Plan ==
Aggregate [txn_date#16], [txn_date#16, sum(amount#3) AS sum(amount

## Small Files Problem
- Caused by excessive partitioning
- Solved using repartition/coalesce
- Impacts query latency and metadata load


In [11]:
df_coalesced = df_repartitioned.coalesce(10)
print("Partitions after coalesce:", df_coalesced.rdd.getNumPartitions())


Partitions after coalesce: 10


In [12]:
df_repartitioned.unpersist()
logger.info("Cache cleared")


2026-02-05 02:41:45,466 - INFO - performance-optimizations - Cache cleared
