## Medallion_etl
 - Bronze -> Silver -> Gold PySpark ETL for Product domain

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import (
    IntegerType, LongType, DoubleType, StringType, DecimalType, TimestampType
)
import sys

### Configurations

In [2]:
# Directory holding the raw source CSV extracts (one <table>.csv per table).
BASE_CSV_PATH = "/home/jovyan/data/project_data"

# Root of the lake; each medallion layer gets its own sub-directory below it.
BASE_PATH = "/home/jovyan/data/lake"
BRONZE_PATH = BASE_PATH + "/bronze"
SILVER_PATH = BASE_PATH + "/silver"
GOLD_PATH = BASE_PATH + "/gold"

### Create Spark Session

In [3]:
# Reuse the active session if one exists, otherwise start a new one.
# Shuffle partitions are set to 8 for this notebook's joins and aggregations.
spark = (
    SparkSession.builder
    .appName("MedallionETL_Product")
    .config("spark.sql.shuffle.partitions", "8")
    .getOrCreate()
)

# Only warnings and errors from Spark are written to the notebook output.
spark.sparkContext.setLogLevel("WARN")

In [4]:
def read_csv(name, schema=None):
    """Load ``{BASE_CSV_PATH}/{name}.csv`` as a Spark DataFrame.

    Args:
        name: Table name; the file read is ``<BASE_CSV_PATH>/<name>.csv``.
        schema: Optional explicit schema (a ``StructType`` or DDL string).
            When given it is applied to the reader and schema inference is
            skipped; when omitted the column types are inferred from the data.

    Returns:
        The DataFrame produced by ``spark.read...csv(path)``. The first line
        of the file is treated as the header.
    """
    path = f"{BASE_CSV_PATH}/{name}.csv"
    reader = spark.read.option("header", True).option("multiLine", False)
    if schema is not None:
        # An explicit schema makes inference (an extra pass over the file) unnecessary.
        reader = reader.schema(schema)
    else:
        reader = reader.option("inferSchema", True)
    return reader.csv(path)

# ---------------------------
# BRONZE Layer: raw data 
# ---------------------------

In [5]:
def bronze_ingest():
    """Copy every source CSV into the bronze layer as Parquet.

    For each table name in the fixed list below, the CSV is loaded with
    ``read_csv`` and written to ``{BRONZE_PATH}/<table>`` in overwrite mode.
    One progress line per table is printed, including the row count.

    Returns:
        None.
    """
    print("=== Bronze ingest ===")
    tables = ["products", "categories", "inventory", "reviews", "orders", "order_items", "suppliers", "product_suppliers"]
    for t in tables:
        df = read_csv(t)
        out = f"{BRONZE_PATH}/{t}"
        df.write.mode("overwrite").parquet(out)
        # Count the rows that were written by reading the Parquet output back,
        # instead of calling df.count(), which would scan the CSV source again.
        written_rows = spark.read.parquet(out).count()
        print(f"Wrote bronze table {t} -> {out} (rows={written_rows})")

In [6]:
# Sanity check: read the bronze products table back and report its size.
# The location is built from BRONZE_PATH so it follows the configuration cell
# instead of repeating the absolute path.
bronze_products = spark.read.parquet(f"{BRONZE_PATH}/products")
print("Bronze products count:", bronze_products.count())

Bronze products count: 1000


# ---------------------------
# SILVER Layer: typed, cleaned and de-duplicated data
# -----------------------------

#### Load bronze

In [7]:

    # Load every bronze table as a DataFrame; the silver cells below start from these names.
    (
        products, categories, inventory, reviews,
        orders, order_items, suppliers, product_suppliers,
    ) = (
        spark.read.parquet(f"{BRONZE_PATH}/{bronze_table}")
        for bronze_table in (
            "products", "categories", "inventory", "reviews",
            "orders", "order_items", "suppliers", "product_suppliers",
        )
    )


 #### 1) Products: cast types, enforce not-null sku, dedupe by sku keeping latest updated_at

In [8]:
  # Cast the bronze product columns to their silver types. Columns not listed
  # here are carried through unchanged.
  p = (
      products
      .withColumn("product_id", F.col("product_id").cast(LongType()))
      .withColumn("sku", F.col("sku").cast(StringType()))
      .withColumn("name", F.col("name").cast(StringType()))
      .withColumn("price", F.col("price").cast(DecimalType(12, 2)))
      .withColumn("category_id", F.col("category_id").cast(LongType()))
      .withColumn("created_at", F.to_timestamp("created_at"))
      .withColumn("updated_at", F.to_timestamp("updated_at"))
  )

 #### Filter out rows with a missing SKU or name, or with price <= 0

In [9]:
    # Keep only products that have a SKU, a name and a strictly positive price.
    # A NULL price fails the comparison, so those rows are dropped as well.
    has_sku_and_name = F.col("sku").isNotNull() & F.col("name").isNotNull()
    p = p.where(has_sku_and_name & (F.col("price") > 0))

#### Deduplicate: keep max(updated_at) per sku

In [10]:
# Rank the rows of each SKU from newest to oldest. product_id is a secondary
# sort key: without it, rows that share the same updated_at (or have no
# updated_at at all) could be ranked differently from run to run, so the
# surviving row would not be reproducible. The tie-break itself is arbitrary
# but stable.
window_spec = Window.partitionBy("sku").orderBy(
    F.col("updated_at").desc_nulls_last(),
    F.col("product_id").desc(),
)

# Add row_number and keep only first row per SKU
p = (
    p.withColumn("rn", F.row_number().over(window_spec))
     .filter(F.col("rn") == 1)
     .drop("rn")
)

In [11]:
# Persist the cleaned products, replacing any previous silver output.
silver_products_path = f"{SILVER_PATH}/products"
p.write.mode("overwrite").parquet(silver_products_path)

# The count is a separate Spark action evaluated on p after the write.
print("Wrote silver products:", p.count())

Wrote silver products: 1000


#### 2) Categories: clean names, ensure parent ids valid (null if not)

In [12]:
# Cast ids and timestamps, and strip surrounding whitespace from category names.
c = (
    categories
    .withColumn("category_id", F.col("category_id").cast(LongType()))
    .withColumn("name", F.trim(F.col("name")))
    .withColumn("parent_id", F.col("parent_id").cast(LongType()))
    .withColumn("created_at", F.to_timestamp("created_at"))
)

# Null out parent ids that do not point at an existing category.
# This is done with a left join rather than collect() + isin(), so the id list
# never has to be pulled onto the driver. The lookup side is made distinct so
# that a repeated category_id cannot multiply category rows.
known_category_ids = c.select(F.col("category_id").alias("valid_parent_id")).distinct()
c = (
    c.join(known_category_ids, F.col("parent_id") == F.col("valid_parent_id"), "left")
     # valid_parent_id equals parent_id on a match and is NULL otherwise.
     .withColumn("parent_id", F.col("valid_parent_id"))
     .drop("valid_parent_id")
)

c.write.mode("overwrite").parquet(f"{SILVER_PATH}/categories")
print("Wrote silver categories:", c.count())

Wrote silver categories: 10


#### 3) Inventory: ensure stock_qty >= 0, cast types

In [13]:
# Cast inventory columns to their silver types. A stock_qty that is missing or
# cannot be cast to an integer becomes 0; negative quantities are then dropped.
inv = (
    inventory
    .withColumn("inventory_id", F.col("inventory_id").cast(LongType()))
    .withColumn("product_id", F.col("product_id").cast(LongType()))
    .withColumn("warehouse_id", F.col("warehouse_id").cast(LongType()))
    .withColumn("stock_qty", F.coalesce(F.col("stock_qty").cast(IntegerType()), F.lit(0)))
    .withColumn("last_updated", F.to_timestamp("last_updated"))
    .where(F.col("stock_qty") >= 0)
)

silver_inventory_path = f"{SILVER_PATH}/inventory"
inv.write.mode("overwrite").parquet(silver_inventory_path)
print("Wrote silver inventory:", inv.count())

Wrote silver inventory: 1000


#### 4) Reviews: ensure rating 1..5

In [14]:
    # Cast review columns to their silver types. A rating that is missing or not
    # castable to an integer becomes 0, so the range check below removes it.
    rev = (
        reviews
        .withColumn("review_id", F.col("review_id").cast(LongType()))
        .withColumn("product_id", F.col("product_id").cast(LongType()))
        .withColumn("user_id", F.col("user_id").cast(LongType()))
        .withColumn("rating", F.coalesce(F.col("rating").cast(IntegerType()), F.lit(0)))
        .withColumn("created_at", F.to_timestamp("created_at"))
    )

    # Keep ratings in the inclusive range 1..5.
    rev = rev.where(F.col("rating").between(1, 5))

    silver_reviews_path = f"{SILVER_PATH}/reviews"
    rev.write.mode("overwrite").parquet(silver_reviews_path)
    print("Wrote silver reviews:", rev.count())

Wrote silver reviews: 5000


#### 5) Orders + order_items: cast types + keep only order_items that reference an existing order

In [15]:
    # Cast order headers to their silver types.
    o = (
        orders
        .withColumn("order_id", F.col("order_id").cast(LongType()))
        .withColumn("user_id", F.col("user_id").cast(LongType()))
        .withColumn("order_date", F.to_timestamp("order_date"))
        .withColumn("total_amount", F.col("total_amount").cast(DecimalType(12, 2)))
    )

    # Cast order lines to their silver types.
    oi = (
        order_items
        .withColumn("order_item_id", F.col("order_item_id").cast(LongType()))
        .withColumn("order_id", F.col("order_id").cast(LongType()))
        .withColumn("product_id", F.col("product_id").cast(LongType()))
        .withColumn("quantity", F.col("quantity").cast(IntegerType()))
        .withColumn("price", F.col("price").cast(DecimalType(12, 2)))
    )

    # Keep only order_items whose order_id exists in orders.
    # A left-semi join does this inside Spark; the previous collect() + isin()
    # pulled every order id onto the driver first. The lookup column is renamed
    # so the join condition is unambiguous and the order_items column order is
    # left untouched.
    valid_order_ids = o.select(F.col("order_id").alias("valid_order_id"))
    oi = oi.join(valid_order_ids, F.col("order_id") == F.col("valid_order_id"), "left_semi")

    # NOTE(review): product_id is not validated against products here, order
    # totals are not recomputed from the items, and orders are written unfiltered.
    oi.write.mode("overwrite").parquet(f"{SILVER_PATH}/order_items")
    o.write.mode("overwrite").parquet(f"{SILVER_PATH}/orders")
    print("Wrote silver orders:", o.count(), "order_items:", oi.count())

Wrote silver orders: 2000 order_items: 5899


#### 6) Suppliers and product_suppliers

In [16]:
    # Suppliers: cast the id and parse the creation timestamp.
    s = (
        suppliers
        .withColumn("supplier_id", F.col("supplier_id").cast(LongType()))
        .withColumn("created_at", F.to_timestamp("created_at"))
    )

    # Product/supplier link table: both keys are cast to long.
    ps = (
        product_suppliers
        .withColumn("product_id", F.col("product_id").cast(LongType()))
        .withColumn("supplier_id", F.col("supplier_id").cast(LongType()))
    )

    s.write.mode("overwrite").parquet(f"{SILVER_PATH}/suppliers")
    ps.write.mode("overwrite").parquet(f"{SILVER_PATH}/product_suppliers")
    print("Wrote silver suppliers:", s.count(), "product_suppliers:", ps.count())

Wrote silver suppliers: 50 product_suppliers: 1960


# ---------------------------
# GOLD Layer: Fact/Dim Table
# ---------------------------

In [17]:

    # Re-bind the table names to the silver outputs; the gold cells below build on these.
    (
        products, categories, inventory, reviews,
        orders, order_items, suppliers, product_suppliers,
    ) = (
        spark.read.parquet(f"{SILVER_PATH}/{silver_table}")
        for silver_table in (
            "products", "categories", "inventory", "reviews",
            "orders", "order_items", "suppliers", "product_suppliers",
        )
    )

#### DIM: dim_products (flatten product + category + supplier list)

In [18]:

# Alias both inputs so the qualified "p." / "c." column names below resolve.
p = products.alias("p")
c = (
    categories
    .select(
        F.col("category_id").alias("cat_id"),
        F.col("name").alias("category_name"),
    )
    .alias("c")
)

# One row per product holding the set of supplier ids linked to it.
prod_sup_join = (
    product_suppliers
    .groupBy("product_id")
    .agg(F.collect_set("supplier_id").alias("supplier_ids"))
)

# Products with their category name. The left join keeps products whose
# category_id has no match; category_name is NULL for those.
products_with_category = (
    p.join(c, p["category_id"] == c["cat_id"], "left")
     .select(
         "p.product_id",
         "p.sku",
         "p.name",
         "p.description",
         "p.price",
         "p.currency",
         "p.status",
         "c.category_name",
         "p.created_at",
     )
)

# Attach supplier ids (an empty array when a product has none) and fix the
# final column order.
dim_products = (
    products_with_category
    .join(prod_sup_join, "product_id", "left")
    .withColumn("supplier_ids", F.coalesce(F.col("supplier_ids"), F.array()))
    .select(
        "product_id",
        "sku",
        "name",
        "description",
        "price",
        "currency",
        "status",
        "category_name",
        "supplier_ids",
        "created_at",
    )
)

 #### write partitioned by status for quick filtering

In [19]:

    # One sub-directory per status value, so filters on status read fewer files.
    dim_products_path = f"{GOLD_PATH}/dim_products"
    dim_products.write.mode("overwrite").partitionBy("status").parquet(dim_products_path)
    print("Wrote dim_products:", dim_products.count())

Wrote dim_products: 1000


 #### FACT: fact_sales (order_items aggregated by product_id and order date; the year/month partition columns are added in the next cell)

In [20]:
# Give every order line its order date (truncated to a calendar day) and its
# line amount (quantity * price).
order_dates = orders.select("order_id", "order_date")
oi = (
    order_items
    .join(order_dates, "order_id", "left")
    .withColumn("order_date", F.to_date("order_date"))
    .withColumn("sales_amount", F.expr("quantity * price"))
)

# One row per product per day: units sold, revenue and number of distinct orders.
fact_sales = (
    oi.groupBy("product_id", "order_date")
      .agg(
          F.sum("quantity").alias("qty_sold"),
          F.sum("sales_amount").alias("sales_amount"),
          F.countDistinct("order_id").alias("order_count"),
      )
)

 #### partition by year/month for query pruning

In [21]:
    # Derive the partition columns from order_date, then write one directory
    # per (year, month) combination.
    fact_sales = (
        fact_sales
        .withColumn("year", F.year("order_date"))
        .withColumn("month", F.month("order_date"))
    )
    fact_sales_path = f"{GOLD_PATH}/fact_sales"
    fact_sales.write.mode("overwrite").partitionBy("year", "month").parquet(fact_sales_path)
    print("Wrote fact_sales:" , fact_sales.count())

Wrote fact_sales: 5826


#### FACT: fact_reviews (avg rating and counts)

In [22]:
    # One row per product: number of reviews and mean rating rounded to 2 decimals.
    review_count = F.count("*").alias("review_count")
    avg_rating = F.round(F.avg("rating"), 2).alias("avg_rating")
    fact_reviews = reviews.groupBy("product_id").agg(review_count, avg_rating)

    fact_reviews.write.mode("overwrite").parquet(f"{GOLD_PATH}/fact_reviews")
    print("Wrote fact_reviews:", fact_reviews.count())

Wrote fact_reviews: 990


#### FACT: fact_inventory (latest snapshot per product)

In [23]:

    # Rank each product's inventory rows from newest to oldest. inventory_id is
    # a secondary sort key: without it, rows sharing the same last_updated could
    # be ranked differently between runs, making the kept row unpredictable.
    # The tie-break itself is arbitrary but stable.
    w = Window.partitionBy("product_id").orderBy(
        F.col("last_updated").desc_nulls_last(),
        F.col("inventory_id").desc(),
    )
    inv_latest = (
        inventory
        .withColumn("rn", F.row_number().over(w))
        .filter(F.col("rn") == 1)
        .drop("rn")
    )
    # NOTE(review): the window ignores warehouse_id, so a product stocked in
    # several warehouses keeps only a single row here — confirm this is intended.
    inv_snapshot = inv_latest.withColumn("snapshot_date", F.to_date("last_updated"))
    inv_snapshot.write.mode("overwrite").parquet(f"{GOLD_PATH}/fact_inventory")
    print("Wrote fact_inventory:", inv_snapshot.count())

Wrote fact_inventory: 1000


### Testing the flow

#### 1) Bronze checks

In [24]:
print("-- Bronze basic counts")

# Row counts for a subset of the bronze tables, read back from Parquet.
bronze_tables_to_check = ("products", "orders", "order_items", "reviews", "inventory")
for t in bronze_tables_to_check:
    df = spark.read.parquet(f"{BRONZE_PATH}/{t}")
    print(t, "rows:", df.count())

-- Bronze basic counts
products rows: 1000
orders rows: 2000
order_items rows: 5899
reviews rows: 5000
inventory rows: 1000


#### 2) Silver checks (data quality)

In [25]:
    print("\n-- Silver data quality checks")
    products = spark.read.parquet(f"{SILVER_PATH}/products")

    # Each check counts offending rows; 0 means the silver rule held.
    # 1) a SKU must appear at most once
    sku_counts = products.groupBy("sku").count()
    dup_count = sku_counts.filter("count > 1").count()
    print("duplicate SKUs in silver.products:", dup_count)

    # 2) a SKU must never be NULL
    null_skus = products.where(F.col("sku").isNull()).count()
    print("null skus:", null_skus)

    # 3) prices must be strictly positive (this counts zero as well as negative prices)
    negative_prices = products.where(F.col("price") <= 0).count()
    print("non-positive prices:", negative_prices)


-- Silver data quality checks
duplicate SKUs in silver.products: 0
null skus: 0
non-positive prices: 0


 #### 3) Silver sample queries for performance
 

In [26]:
    print("\n-- Silver example queries (use caching for repeated runs)")

    # Expose the silver tables to Spark SQL under stable view names.
    # `products` is the silver products DataFrame loaded by the data quality cell.
    inventory = spark.read.parquet(f"{SILVER_PATH}/inventory")
    for view_name, view_df in (
        ("silver_products", products),
        ("silver_inventory", inventory),
    ):
        view_df.createOrReplaceTempView(view_name)


-- Silver example queries (use caching for repeated runs)


#### Query 1: Products with low stock (join + filter)

In [27]:

    # Products with fewer than 10 units in stock, lowest stock first (at most 50 rows).
    low_stock_sql = """
      SELECT p.product_id, p.sku, p.name, i.stock_qty
      FROM silver_products p
      JOIN silver_inventory i ON p.product_id = i.product_id
      WHERE i.stock_qty < 10
      ORDER BY i.stock_qty ASC
      LIMIT 50
    """
    qA = spark.sql(low_stock_sql)
    # Only the first 10 of those rows are displayed.
    qA.show(10, truncate=False)

+----------+-------+-------------------+---------+
|product_id|sku    |name               |stock_qty|
+----------+-------+-------------------+---------+
|898       |SKU1897|Because Just       |0        |
|783       |SKU1782|Also Spend         |0        |
|817       |SKU1816|Thousand Firm      |0        |
|561       |SKU1560|Or His             |1        |
|730       |SKU1729|Space Purpose      |1        |
|217       |SKU1216|Half Store         |2        |
|621       |SKU1620|Degree Key         |2        |
|134       |SKU1133|Thought Staff      |2        |
|977       |SKU1976|Mission No         |2        |
|82        |SKU1081|Collection Specific|3        |
+----------+-------+-------------------+---------+
only showing top 10 rows



#### Query 2: Top 10 products by price (simple)

In [28]:
    # The ten most expensive products in silver.
    top_price_sql = """
      SELECT sku, name, price FROM silver_products
      ORDER BY price DESC
      LIMIT 10
    """
    qB = spark.sql(top_price_sql)
    qB.show(10, truncate=False)

+-------+----------------+------+
|sku    |name            |price |
+-------+----------------+------+
|SKU1531|Them Home       |499.86|
|SKU1847|North Forward   |498.45|
|SKU1532|Call Charge     |498.34|
|SKU1475|Still Style     |496.52|
|SKU1390|Foot Help       |495.30|
|SKU1261|School Receive  |495.08|
|SKU1705|Tell Media      |495.03|
|SKU1518|Hair Agency     |493.70|
|SKU1069|Knowledge Decade|493.51|
|SKU1441|Now Few         |493.39|
+-------+----------------+------+



#### 4) Gold query performance tests (aggregations & partition pruning)

In [29]:
    # Section banner in the notebook output; the gold-layer queries follow in the next cells.
    print("\n-- Gold queries (fact & dim)")


-- Gold queries (fact & dim)


#### Query 3: top selling products overall (join with dim_products)

In [30]:
    # Reload the gold tables from disk and expose them to Spark SQL.
    fact_sales = spark.read.parquet(f"{GOLD_PATH}/fact_sales")
    dim_products = spark.read.parquet(f"{GOLD_PATH}/dim_products")
    for gold_view, gold_df in (
        ("gold_fact_sales", fact_sales),
        ("gold_dim_products", dim_products),
    ):
        gold_df.createOrReplaceTempView(gold_view)

    # Top 20 products by total sales amount across all dates.
    top_sellers_sql = """
      SELECT d.product_id, d.sku, d.name, SUM(f.sales_amount) as total_sales, SUM(f.qty_sold) as total_qty
      FROM gold_fact_sales f
      JOIN gold_dim_products d ON f.product_id = d.product_id
      GROUP BY d.product_id, d.sku, d.name
      ORDER BY total_sales DESC
      LIMIT 20
    """
    q1 = spark.sql(top_sellers_sql)
    q1.show(20, truncate=False)

+----------+-------+-------------------+-----------+---------+
|product_id|sku    |name               |total_sales|total_qty|
+----------+-------+-------------------+-----------+---------+
|948       |SKU1947|Government Yourself|35139.78   |78       |
|44        |SKU1043|Trouble Simply     |35043.47   |77       |
|248       |SKU1247|Foot News          |33844.30   |70       |
|474       |SKU1473|Life Process       |30782.28   |69       |
|262       |SKU1261|School Receive     |28714.64   |58       |
|396       |SKU1395|Picture Mr         |27943.20   |60       |
|712       |SKU1711|Until System       |27585.28   |92       |
|306       |SKU1305|Guess Conference   |27330.80   |56       |
|229       |SKU1228|Enter Size         |26793.60   |60       |
|89        |SKU1088|South Necessary    |26206.03   |59       |
|273       |SKU1272|Remember Begin     |26052.00   |60       |
|442       |SKU1441|Now Few            |25656.28   |52       |
|52        |SKU1051|Light Sea          |25264.85   |65 

#### Query 4: avg rating per category (join fact_reviews -> dim_products -> categories)
    

In [31]:
    # Load the review facts and the silver categories, and register both as views.
    # gold_dim_products is the view registered by the Query 3 cell.
    fact_reviews = spark.read.parquet(f"{GOLD_PATH}/fact_reviews")
    categories = spark.read.parquet(f"{SILVER_PATH}/categories")
    fact_reviews.createOrReplaceTempView("gold_fact_reviews")
    categories.createOrReplaceTempView("silver_categories")

    # Mean of the per-product average ratings, and total review count, per category.
    # NOTE(review): dim_products is joined to categories on the category *name*;
    # if two categories ever share a name this join would double count — confirm
    # names are unique.
    category_rating_sql = """
      SELECT c.name as category, ROUND(AVG(fr.avg_rating),2) as avg_rating, SUM(fr.review_count) as total_reviews
      FROM gold_fact_reviews fr
      JOIN gold_dim_products d ON fr.product_id = d.product_id
      JOIN silver_categories c ON d.category_name = c.name
      GROUP BY c.name
      ORDER BY avg_rating DESC
      LIMIT 20
    """
    spark.sql(category_rating_sql).show(20, truncate=False)

+--------+----------+-------------+
|category|avg_rating|total_reviews|
+--------+----------+-------------+
|Capital |3.08      |523          |
|Hot     |3.02      |539          |
|Before  |3.02      |428          |
|Create  |2.98      |502          |
|Animal  |2.98      |470          |
|Force   |2.97      |430          |
|Wind    |2.95      |600          |
|Series  |2.92      |470          |
|Large   |2.9       |478          |
|Anyone  |2.89      |560          |
+--------+----------+-------------+



#### Query 5: filter by particular year/month
    

In [32]:
print("\n-- filter by particular year/month")

# September 2025. year and month are the columns fact_sales was partitioned by
# when it was written to gold.
in_target_month = (F.col("year") == 2025) & (F.col("month") == 9)
filtered_sales = fact_sales.where(in_target_month)
filtered_sales.show(20, truncate=False)


-- filter by particular year/month
+----------+----------+--------+------------+-----------+----+-----+
|product_id|order_date|qty_sold|sales_amount|order_count|year|month|
+----------+----------+--------+------------+-----------+----+-----+
|86        |2025-09-01|8       |1327.76     |1          |2025|9    |
|957       |2025-09-02|2       |523.72      |1          |2025|9    |
|349       |2025-09-02|7       |1333.01     |1          |2025|9    |
|23        |2025-09-05|7       |1047.27     |1          |2025|9    |
|717       |2025-09-03|7       |2913.12     |1          |2025|9    |
|167       |2025-09-03|3       |491.37      |1          |2025|9    |
|22        |2025-09-01|1       |74.05       |1          |2025|9    |
|306       |2025-09-05|10      |4880.50     |1          |2025|9    |
|635       |2025-09-06|3       |174.15      |1          |2025|9    |
|100       |2025-09-06|1       |58.41       |1          |2025|9    |
|295       |2025-09-04|8       |543.12      |1          |2025|9    

#### Query 6: Avg rating per category (DataFrame API version of Query 4, using dim_products only)

In [33]:
    print("\n-- Avg Rating per Category")
    # Built from GOLD_PATH so the location follows the configuration cell
    # instead of repeating the absolute path.
    fact_reviews = spark.read.parquet(f"{GOLD_PATH}/fact_reviews")

    # Mean of the per-product average ratings, and total review count, per
    # category name. dim_products is the gold DataFrame loaded in the Query 3 cell.
    avg_rating_cat = (
        fact_reviews.join(dim_products, "product_id")
        .groupBy("category_name")
        .agg(
            F.round(F.avg("avg_rating"), 2).alias("avg_rating"),
            F.sum("review_count").alias("total_reviews"),
        )
        .orderBy(F.desc("avg_rating"))
        .limit(20)
    )
    avg_rating_cat.show(truncate=False)


-- Avg Rating per Category
+-------------+----------+-------------+
|category_name|avg_rating|total_reviews|
+-------------+----------+-------------+
|Capital      |3.08      |523          |
|Hot          |3.02      |539          |
|Before       |3.02      |428          |
|Create       |2.98      |502          |
|Animal       |2.98      |470          |
|Force        |2.97      |430          |
|Wind         |2.95      |600          |
|Series       |2.92      |470          |
|Large        |2.9       |478          |
|Anyone       |2.89      |560          |
+-------------+----------+-------------+



# ------------------------+
# Main flow               |
# ------------------------+

In [34]:

# Script-style entry point. The captured output below shows this cell also runs
# inside the notebook kernel.
if __name__ == "__main__":
    # NOTE(review): this is the only call to bronze_ingest() in the notebook, yet
    # the cells above already read from BRONZE_PATH. On a fresh lake, Restart &
    # Run All would presumably fail before reaching this cell — confirm, and
    # consider calling bronze_ingest() right after its definition.
    bronze_ingest()
    # NOTE(review): the next two lines only print stage names. No
    # silver_transform() or gold_build() function is defined in the visible
    # notebook; the silver and gold work is done by the cells above.
    print("=== silver_transform() ===")
    print("=== gold_build() ===")
    print("ETL completed.")
    # Stops the SparkSession; cells re-run after this need a new session first.
    spark.stop()

=== Bronze ingest ===
Wrote bronze table products -> /home/jovyan/data/lake/bronze/products (rows=1000)
Wrote bronze table categories -> /home/jovyan/data/lake/bronze/categories (rows=10)
Wrote bronze table inventory -> /home/jovyan/data/lake/bronze/inventory (rows=1000)
Wrote bronze table reviews -> /home/jovyan/data/lake/bronze/reviews (rows=5000)
Wrote bronze table orders -> /home/jovyan/data/lake/bronze/orders (rows=2000)
Wrote bronze table order_items -> /home/jovyan/data/lake/bronze/order_items (rows=5899)
Wrote bronze table suppliers -> /home/jovyan/data/lake/bronze/suppliers (rows=50)
Wrote bronze table product_suppliers -> /home/jovyan/data/lake/bronze/product_suppliers (rows=1960)
=== silver_transform() ===
=== gold_build() ===
ETL completed.
