## Parameters

## Imports and Setup

## Dimension Table Creation

### Customers Dimension

### Days Dimension

In [4]:
# COMMAND ----------
# ✨ Parameters ───────────────────────────────────────────────────────────────
num_customers = 100  # rows in customer dimension
num_products = 10  # rows in product dimension
num_facts = 100_000_000  # total rows across all years (will be split evenly among months for the chosen year)
start_ts = "2013-01-01 00:00:00"  # baseline start (used to infer full range but overridden by load_year)
load_year = 2017  # year to generate data for (must be between 2013 and 2022)
drop_tables = False  # set to True to drop fact and agg tables before loading
months = list(range(1, 13))
# ----------------------------------------------------------------------------

StatementMeta(, 88c09812-4909-4bfc-bb57-9cb8fb8d592e, 6, Finished, Available, Finished)

In [2]:
from datetime import datetime

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    concat_ws,
    dayofmonth,
    expr,
    from_unixtime,
    lit,
)

spark = SparkSession.builder.getOrCreate()

StatementMeta(, 88c09812-4909-4bfc-bb57-9cb8fb8d592e, 4, Finished, Available, Finished)

## Fact & Aggregate Table Loading

In [21]:
def _bucket(col_name: str, num_buckets: int, labels):
    """Return a Spark SQL expression that maps modulus buckets to text labels."""
    assert len(labels) == num_buckets, "labels must match num_buckets"
    cases = (
        "CASE "
        + " ".join(
            [
                f"WHEN {col_name} % {num_buckets} = {i} THEN '{labels[i]}'"
                for i in range(num_buckets)
            ]
        )
        + " END"
    )
    return expr(cases)


# ---------------------------------------------------------------------------

# COMMAND ----------
# 🏷️  Dimension Table 1 – Customers

df_customers = (
    spark.range(1, num_customers + 1)
    .withColumnRenamed("id", "customer_id")
    .withColumn("customer_name", concat_ws(" ", lit("Customer"), expr("customer_id")))
    .withColumn("region", _bucket("customer_id", 4, ["North", "South", "East", "West"]))
)

(df_customers.write.mode("overwrite").format("delta").saveAsTable("dim_customer"))

# ---------------------------------------------------------------------------

# COMMAND ----------
# 🏷️  Dimension Table 2 – Days


days = list(range(1, 32))
df_day = spark.createDataFrame([(d,) for d in days], ["day"]).withColumn(
    "day_label", concat_ws("", lit("Day "), col("day"))
)

(df_day.write.mode("overwrite").format("delta").saveAsTable("dim_day"))

# ---------------------------------------------------------------------------

# COMMAND ----------
# 🏷️  Dimension Table 3 – Months


month_names = [
    "Jan",
    "Feb",
    "Mar",
    "Apr",
    "May",
    "Jun",
    "Jul",
    "Aug",
    "Sep",
    "Oct",
    "Nov",
    "Dec",
]
df_month = spark.createDataFrame(
    [(m, month_names[m - 1]) for m in months], ["month", "month_name"]
).withColumn("month_label", concat_ws(" ", lit("Month"), col("month_name")))

(df_month.write.mode("overwrite").format("delta").saveAsTable("dim_month"))

# ---------------------------------------------------------------------------

# COMMAND ----------
# 🏷️  Dimension Table 4 – Years

years_range = list(range(2013, 2023))  # 10 years
df_year = spark.createDataFrame([(y,) for y in years_range], ["year"]).withColumn(
    "year_label", concat_ws("", lit("Year "), col("year"))
)

(df_year.write.mode("overwrite").format("delta").saveAsTable("dim_year"))

# ---------------------------------------------------------------------------

# COMMAND ----------
# 🏷️  Dimension Table 5 – Products

df_products = (
    spark.range(1, num_products + 1)
    .withColumnRenamed("id", "product_id")
    .withColumn("product_name", concat_ws(" ", lit("Product"), expr("product_id")))
    .withColumn(
        "category",
        _bucket(
            "product_id",
            5,
            ["Accessories", "Hardware", "Software", "Services", "Other"],
        ),
    )
)

(df_products.write.mode("overwrite").format("delta").saveAsTable("dim_product"))

# ---------------------------------------------------------------------------

StatementMeta(, 51afabbe-d8f0-45de-92e6-64767a52e1b3, 23, Finished, Available, Finished)

In [6]:
# COMMAND ----------
# 📊  Fact & Aggregate Tables Loading

# Optionally drop existing tables
def drop_existing_tables():
    spark.sql("DROP TABLE IF EXISTS fact_transactions")
    spark.sql("DROP TABLE IF EXISTS agg_transactions")


if drop_tables:
    drop_existing_tables()
    print("Dropped existing fact_transactions and agg_transactions tables.")
else:
    print("Skipping drop of existing tables.")

# Determine facts per month only for load_year
total_months = 12  # months in selected year
facts_per_month = num_facts // total_months

# Generate and write partitions month-by-month for load_year
for m in months:
    # Calculate epoch range for this month
    month_start = datetime(load_year, m, 1, 0, 0, 0)
    if m < 12:
        month_end = datetime(load_year, m + 1, 1, 0, 0, 0)
    else:
        month_end = datetime(load_year + 1, 1, 1, 0, 0, 0)
    epoch_start = int(month_start.timestamp())
    epoch_end = int(month_end.timestamp())
    seconds_in_month = epoch_end - epoch_start

    # Generate DataFrame for this month’s facts
    df_month_batch = (
        spark.range(1, facts_per_month + 1)
        .withColumnRenamed("id", "transaction_id")
        .withColumn(
            "txn_timestamp",
            from_unixtime(
                expr(f"CAST(rand() * {seconds_in_month} AS BIGINT) + {epoch_start}")
            ).cast("timestamp"),
        )
        .withColumn("year", lit(load_year))
        .withColumn("month", lit(m))
        .withColumn("day", dayofmonth(col("txn_timestamp")))
        .withColumn("customer_id", expr(f"floor(rand() * {num_customers}) + 1"))
        .withColumn("product_id", expr(f"floor(rand() * {num_products}) + 1"))
        .withColumn("quantity", expr("floor(rand()*10) + 1"))
        .withColumn("unit_price", expr("round(rand()*99 + 1, 2)"))
        .withColumn("amount", expr("round(quantity * unit_price, 2)"))
    )

    # Write or replace partition in fact_transactions
    (
        df_month_batch.orderBy("year", "month", "product_id", "customer_id")
        .write.format("delta")
        .mode("overwrite")
        .option("replaceWhere", f"year = {load_year} AND month = {m}")
        .partitionBy("year", "month")
        .saveAsTable("fact_transactions")
    )
    print(
        f"Loaded fact partition year={load_year}, month={m} with {facts_per_month} rows"
    )

    # 📈  Aggregate for this month only
    # df_agg_month = (
    #    df_month_batch
    #      .groupBy("product_id", "customer_id", "year", "month")
    #      .agg(
    #          spark_sum("quantity").alias("total_quantity"),
    #          spark_sum("amount").alias("total_amount")
    #      )
    # )

    # Write or replace partition in agg_transactions
    # (df_agg_month
    #   .write
    #   .format("delta")
    #   .mode("overwrite")
    #   # .option("replaceWhere", f"year = {load_year} AND month = {m}")
    #   # .partitionBy("year", "month")
    #   .saveAsTable("agg_transactions")
    # )
    # print(f"Loaded agg partition year={load_year}, month={m}")

# ---------------------------------------------------------------------------

StatementMeta(, 88c09812-4909-4bfc-bb57-9cb8fb8d592e, 8, Finished, Available, Finished)

Skipping drop of existing tables.
Loaded fact partition year=2017, month=1 with 8333333 rows
Loaded fact partition year=2017, month=2 with 8333333 rows
Loaded fact partition year=2017, month=3 with 8333333 rows
Loaded fact partition year=2017, month=4 with 8333333 rows
Loaded fact partition year=2017, month=5 with 8333333 rows
Loaded fact partition year=2017, month=6 with 8333333 rows
Loaded fact partition year=2017, month=7 with 8333333 rows
Loaded fact partition year=2017, month=8 with 8333333 rows
Loaded fact partition year=2017, month=9 with 8333333 rows
Loaded fact partition year=2017, month=10 with 8333333 rows
Loaded fact partition year=2017, month=11 with 8333333 rows
Loaded fact partition year=2017, month=12 with 8333333 rows


## Aggregation

In [7]:
df_agg = spark.sql(
    "Select year, month, product_id, customer_id, sum(quantity) as total_quantity, sum(amount) as total_amount  from fact_transactions group by  year, month, product_id, customer_id"
)
(
    df_agg.orderBy("year", "month", "product_id", "customer_id")
    .write.mode("overwrite")
    .format("delta")
    # .partitionBy("year", "month")
    .option("parquet.vorder.enabled", "force_true")
    .saveAsTable("agg_transactions")
)

StatementMeta(, 88c09812-4909-4bfc-bb57-9cb8fb8d592e, 9, Finished, Available, Finished)

## Sanity Checks

In [31]:
# COMMAND ----------
# 🔍  Quick sanity checks (optional)
print("Rows in dim_customer     :", df_customers.count())
print("Rows in dim_day          :", df_day.count())
print("Rows in dim_month        :", df_month.count())
print("Rows in dim_year         :", df_year.count())
print("Rows in dim_product      :", df_products.count())
# Example: count a specific partition (e.g., load_year-Jan)
count_jan = spark.table("fact_transactions").where(f"year={load_year}").count()
print(f"Rows in fact {load_year}     :", count_jan)
count_jan = spark.table("agg_transactions").where(f"year={load_year}").count()
print(f"Rows in agg {load_year}     :", count_jan)

StatementMeta(, 51afabbe-d8f0-45de-92e6-64767a52e1b3, 33, Finished, Available, Finished)

Rows in dim_customer     : 100
Rows in dim_day          : 31
Rows in dim_month        : 12
Rows in dim_year         : 10
Rows in dim_product      : 10
Rows in fact 2016     : 99999996
Rows in agg 2016     : 12000


In [8]:
df_batch = spark.sql("select * from fact_transactions")
(
    df_batch.orderBy("year", "month", "product_id", "customer_id")
    .write.mode("overwrite")
    .format("delta")
    .partitionBy("year", "month")
    .option("parquet.vorder.enabled", "force_true")
    .saveAsTable("fact_transactions")
)

StatementMeta(, 88c09812-4909-4bfc-bb57-9cb8fb8d592e, 10, Finished, Available, Finished)

In [9]:
from delta.tables import DeltaTable

delta_table = DeltaTable.forPath(spark, "Tables/fact_transactions")
delta_table.optimize().executeCompaction()

StatementMeta(, 88c09812-4909-4bfc-bb57-9cb8fb8d592e, 11, Finished, Available, Finished)

DataFrame[path: string, metrics: struct<numFilesAdded:bigint,numFilesRemoved:bigint,numFilesUpdatedWithoutRewrite:bigint,filesAdded:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>,filesRemoved:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>,filesUpdatedWithoutRewrite:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>,filesRemovedBreakdown:array<struct<reason:string,metrics:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>>>,partitionsOptimized:bigint,zOrderStats:struct<strategyName:string,inputCubeFiles:struct<num:bigint,size:bigint>,inputOtherFiles:struct<num:bigint,size:bigint>,inputNumCubes:bigint,mergedFiles:struct<num:bigint,size:bigint>,numOutputCubes:bigint,mergedNumCubes:bigint>,clusteringStats:struct<inputZCubeFiles:struct<numFiles:bigint,size:bigint>,inputOtherFiles:struct<numFiles:bigint,size:bigint>,inputNumZCubes:bigint,mergedFiles:struct<numFiles:bigint,size:bigint>,

In [10]:
from delta.tables import DeltaTable

delta_table = DeltaTable.forPath(spark, "Tables/agg_transactions")
delta_table.optimize().executeCompaction()

StatementMeta(, 88c09812-4909-4bfc-bb57-9cb8fb8d592e, 12, Finished, Available, Finished)

DataFrame[path: string, metrics: struct<numFilesAdded:bigint,numFilesRemoved:bigint,numFilesUpdatedWithoutRewrite:bigint,filesAdded:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>,filesRemoved:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>,filesUpdatedWithoutRewrite:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>,filesRemovedBreakdown:array<struct<reason:string,metrics:struct<min:bigint,max:bigint,avg:double,totalFiles:bigint,totalSize:bigint>>>,partitionsOptimized:bigint,zOrderStats:struct<strategyName:string,inputCubeFiles:struct<num:bigint,size:bigint>,inputOtherFiles:struct<num:bigint,size:bigint>,inputNumCubes:bigint,mergedFiles:struct<num:bigint,size:bigint>,numOutputCubes:bigint,mergedNumCubes:bigint>,clusteringStats:struct<inputZCubeFiles:struct<numFiles:bigint,size:bigint>,inputOtherFiles:struct<numFiles:bigint,size:bigint>,inputNumZCubes:bigint,mergedFiles:struct<numFiles:bigint,size:bigint>,