In [11]:
# ============================================
# 0. Imports & Spark session
# ============================================

import time
import builtins  # <-- IMPORTANT
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    avg,
    round as spark_round,   # Spark round ONLY for Columns
    count,
    col,
    sum as _sum
)

spark = (
    SparkSession.builder
    .appName("PostgresVsSparkBenchmark")
    .config("spark.jars.packages", "org.postgresql:postgresql:42.7.2")
    .config("spark.eventLog.enabled", "true")
    .config("spark.eventLog.dir", "/tmp/spark-events")
    .config("spark.history.fs.logDirectory", "/tmp/spark-events")
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.default.parallelism", "4")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("WARN")


In [12]:
# ============================================
# 1. JDBC connection config
# ============================================

jdbc_url = "jdbc:postgresql://postgres:5432/postgres"
jdbc_props = {
    "user": "postgres",
    "password": "postgres",
    "driver": "org.postgresql.Driver"
}

In [13]:
# ============================================
# 2. Load data from PostgreSQL
# ============================================

print("\n=== Loading people_big from PostgreSQL ===")

start = time.time()

df_big = spark.read.jdbc(
    url=jdbc_url,
    table="people_big",
    properties=jdbc_props
)

# Force materialization
row_count = df_big.count()

print(f"Rows loaded: {row_count}")
print("Load time:", builtins.round(time.time() - start, 2), "seconds")

# Register temp view
df_big.createOrReplaceTempView("people_big")



=== Loading people_big from PostgreSQL ===
Rows loaded: 1000000
Load time: 0.49 seconds


In [14]:
# ============================================
# 3. Query (a): Simple aggregation
# ============================================

print("\n=== Query (a): AVG salary per department ===")

start = time.time()

q_a = (
    df_big
    .groupBy("department")
    .agg(spark_round(avg("salary"), 2).alias("avg_salary"))
    .orderBy("department", ascending=False)
    .limit(10)
)

q_a.collect()
q_a.show(truncate=False)
print("Query (a) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (a): AVG salary per department ===
+------------------+----------+
|department        |avg_salary|
+------------------+----------+
|Workforce Planning|85090.82  |
|Web Development   |84814.36  |
|UX Design         |84821.2   |
|UI Design         |85164.64  |
|Treasury          |84783.27  |
|Training          |85148.1   |
|Tax               |85018.57  |
|Sustainability    |85178.99  |
|Supply Chain      |84952.89  |
|Subscriptions     |84899.19  |
+------------------+----------+

Query (a) time: 1.57 seconds


In [15]:
# ============================================
# 4. Query (b): Nested aggregation
# ============================================

print("\n=== Query (b): Nested aggregation ===")

start = time.time()

q_b = spark.sql("""
SELECT country, AVG(avg_salary) AS avg_salary
FROM (
    SELECT country, department, AVG(salary) AS avg_salary
    FROM people_big
    GROUP BY country, department
) sub
GROUP BY country
ORDER BY avg_salary DESC
LIMIT 10
""")

q_b.collect()
q_b.show(truncate=False)
print("Query (b) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (b): Nested aggregation ===
+------------+-----------------+
|country     |avg_salary       |
+------------+-----------------+
|Egypt       |87382.229633112  |
|Kuwait      |87349.3517377211 |
|Saudi Arabia|87348.80512175433|
|Panama      |87345.00623707911|
|Denmark     |87328.03514120901|
|Jamaica     |87305.437352083  |
|Lebanon     |87292.76891750695|
|Turkey      |87290.69043798617|
|Malaysia    |87253.78746341489|
|Kazakhstan  |87251.74274968785|
+------------+-----------------+

Query (b) time: 1.23 seconds


In [16]:
# ============================================
# 5. Query (c): Sorting + Top-N
# ============================================

print("\n=== Query (c): Top 10 salaries ===")

start = time.time()

q_c = (
    df_big
    .orderBy(col("salary").desc())
    .limit(10)
)

q_c.collect()
q_c.show(truncate=False)
print("Query (c) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (c): Top 10 salaries ===
+------+----------+---------+------+----------------+------+------------+
|id    |first_name|last_name|gender|department      |salary|country     |
+------+----------+---------+------+----------------+------+------------+
|764650|Tim       |Jensen   |Male  |Analytics       |160000|Bulgaria    |
|10016 |Anastasia |Edwards  |Female|Analytics       |159998|Kuwait      |
|754528|Adrian    |Young    |Male  |Game Analytics  |159997|UK          |
|240511|Diego     |Lopez    |Male  |Game Analytics  |159995|Malaysia    |
|893472|Mariana   |Cook     |Female|People Analytics|159995|South Africa|
|359891|Mariana   |Novak    |Female|Game Analytics  |159992|Mexico      |
|53102 |Felix     |Taylor   |Male  |Data Science    |159989|Bosnia      |
|768143|Teresa    |Campbell |Female|Game Analytics  |159988|Spain       |
|729165|Antonio   |Weber    |Male  |Analytics       |159987|Moldova     |
|952549|Adrian    |Harris   |Male  |Analytics       |159986|Georgia     |
+-

In [17]:
# ============================================
# 6. Query (d): Heavy self-join (COUNT only)
# ============================================

print("\n=== Query (d): Heavy self-join COUNT (DANGEROUS) ===")

start = time.time()

q_d = (
    df_big.alias("p1")
    .join(df_big.alias("p2"), on="country")
    .count()
)

print("Join count:", q_d)
print("Query (d) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (d): Heavy self-join COUNT (DANGEROUS) ===
Join count: 10983941260
Query (d) time: 8.4 seconds


In [18]:

# ============================================
# 7. Query (d-safe): Join-equivalent rewrite
# ============================================

print("\n=== Query (d-safe): Join-equivalent rewrite ===")

start = time.time()

grouped = df_big.groupBy("country").agg(count("*").alias("cnt"))

q_d_safe = grouped.select(
    _sum(col("cnt") * col("cnt")).alias("total_pairs")
)

q_d_safe.collect()
q_d_safe.show()
print("Query (d-safe) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (d-safe): Join-equivalent rewrite ===
+-----------+
|total_pairs|
+-----------+
|10983941260|
+-----------+

Query (d-safe) time: 1.0 seconds


In [19]:
# ============================================
# 8. Load ecommerce orders_big
# ============================================

print("\n=== Loading orders_big from PostgreSQL ===")

start = time.time()

orders_df = spark.read.jdbc(
    url=jdbc_url,
    table="orders_big",
    properties=jdbc_props
)

# Force materialization
orders_count = orders_df.count()

print(f"Orders rows loaded: {orders_count}")
print("Load time:", builtins.round(time.time() - start, 2), "seconds")

# Register temp view
orders_df.createOrReplaceTempView("orders_big")



=== Loading orders_big from PostgreSQL ===
Orders rows loaded: 1000000
Load time: 0.68 seconds


In [20]:
# ============================================
# 9. Query (A): Highest price_per_category
# ============================================

print("\n=== Query (A): Highest price_per_category ===")

start = time.time()

q_ecom_a = (
    orders_df
    .groupBy("product_category")
    .agg(
        spark_round(_sum(col("price_per_unit") * col("quantity")), 2).alias("price_per_category")
    )
    .orderBy(col("price_per_category").desc())
    .limit(1)
)

q_ecom_a.collect()
q_ecom_a.show(truncate=False)
print("Query (A) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (A): Highest price_per_category ===
+----------------+------------------+
|product_category|price_per_category|
+----------------+------------------+
|Automotive      |306589798.86      |
+----------------+------------------+

Query (A) time: 2.34 seconds


In [21]:
# ============================================
# 10. Query (B): Top 3 categories by quantity
# ============================================

print("\n=== Query (B): Top categories by quantity ===")

start = time.time()

q_ecom_b = (
    orders_df
    .groupBy("product_category")
    .agg(_sum(col("quantity")).alias("total_quantity_sold"))
    .orderBy(col("total_quantity_sold").desc())
    .limit(3)
)

q_ecom_b.collect()
q_ecom_b.show(truncate=False)
print("Query (B) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (B): Top categories by quantity ===
+----------------+-------------------+
|product_category|total_quantity_sold|
+----------------+-------------------+
|Health & Beauty |300842             |
|Electronics     |300804             |
|Toys            |300598             |
+----------------+-------------------+

Query (B) time: 0.92 seconds


In [22]:
# ============================================
# 11. Query (C): Revenue per product category
# ============================================

print("\n=== Query (C): Revenue per product category ===")

start = time.time()

q_ecom_c = (
    orders_df
    .groupBy("product_category")
    .agg(spark_round(_sum(col("price_per_unit") * col("quantity")), 2).alias("revenue"))
    .orderBy(col("revenue").desc())
)

q_ecom_c.collect()
q_ecom_c.show(truncate=False)
print("Query (C) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (C): Revenue per product category ===
+----------------+------------+
|product_category|revenue     |
+----------------+------------+
|Automotive      |306589798.86|
|Electronics     |241525009.45|
|Home & Garden   |78023780.09 |
|Sports          |61848990.83 |
|Health & Beauty |46599817.89 |
|Office Supplies |38276061.64 |
|Fashion         |31566368.22 |
|Toys            |23271039.02 |
|Grocery         |15268355.66 |
|Books           |12731976.04 |
+----------------+------------+

Query (C) time: 2.18 seconds


In [23]:
# ============================================
# 12. Query (D): Top customers by total spending
# ============================================

print("\n=== Query (D): Top customers by total spending ===")

start = time.time()

q_ecom_d = (
    orders_df
    .groupBy("customer_name")
    .agg(spark_round(_sum(col("price_per_unit") * col("quantity")), 2).alias("total_spending"))
    .orderBy(col("total_spending").desc())
    .limit(10)
)

q_ecom_d.collect()
q_ecom_d.show(truncate=False)
print("Query (D) time:", builtins.round(time.time() - start, 2), "seconds")



=== Query (D): Top customers by total spending ===
+--------------+--------------+
|customer_name |total_spending|
+--------------+--------------+
|Carol Taylor  |991179.18     |
|Nina Lopez    |975444.95     |
|Daniel Jackson|959344.48     |
|Carol Lewis   |947708.57     |
|Daniel Young  |946030.14     |
|Alice Martinez|935100.02     |
|Ethan Perez   |934841.24     |
|Leo Lee       |934796.48     |
|Eve Young     |933176.86     |
|Ivy Rodriguez |925742.64     |
+--------------+--------------+

Query (D) time: 2.43 seconds


In [24]:
# ============================================
# 13. Cleanup
# ============================================

spark.stop()
