In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from delta.tables import DeltaTable

# =========================
# CONFIG
# =========================
# Catalog / schema the notebook runs in. The catalog name contains a hyphen,
# which is why it is back-quoted wherever it is interpolated into SQL.
catalog_name = "electricity-project"
silver_schema = "silver"

# Source (joined price + weather rows) and destination (feature rows) tables.
input_table = "silver.price_weather_joined"
output_table = "silver.price_features"

# How far behind the newest stored feature row each incremental run re-reads.
LOOKBACK_HOURS = 48

# =========================
# CATALOG + SCHEMA
# =========================
# Select the session's current catalog first, then the schema inside it.
for use_statement in (
    f"USE CATALOG `{catalog_name}`",
    f"USE SCHEMA {silver_schema}",
):
    spark.sql(use_statement)

# =========================
# DETERMINE WATERMARK
# =========================
# The watermark is the newest `datetime` already present in the output table.
# It stays None when the table does not exist yet (or holds no rows), which
# makes the read below a full load.
max_dt = None
if spark.catalog.tableExists(output_table):
    watermark_row = (
        spark.table(output_table)
        .agg(F.max("datetime").alias("max_dt"))
        .first()
    )
    max_dt = watermark_row["max_dt"]

# =========================
# READ INPUT (INCREMENTAL)
# =========================
df = spark.table(input_table)

if max_dt is not None:
    # Re-read everything from LOOKBACK_HOURS before the watermark onwards
    # (inclusive), so recent rows are reprocessed together with new ones.
    lookback_start = F.lit(max_dt) - F.expr(f"INTERVAL {LOOKBACK_HOURS} HOURS")
    df = df.where(F.col("datetime") >= lookback_start)

# =========================
# BASIC TIME FEATURES
# =========================
# Spark's dayofweek() numbers days 1 (Sunday) .. 7 (Saturday). Shifting by 5
# modulo 7 re-bases that to 0 (Monday) .. 6 (Sunday).
monday_based_day = (F.dayofweek("datetime") + 5) % 7

df = df.withColumn("hour", F.hour("datetime"))
df = df.withColumn("day_of_week", monday_based_day)

# =========================
# WINDOWS (SAFE PARTITIONING)
# =========================
# Per-calendar-day window, ordered by time. Partitioning by date avoids
# pulling every row into a single window partition.
trend_window = (
    Window
    .partitionBy(F.to_date("datetime"))
    .orderBy("datetime")
)

# =========================
# TREND
# =========================
# Zero-based position of each row within its calendar day.
# NOTE(review): because the window restarts every day this is an intraday
# index, not a global time trend -- confirm that this is the intent.
df = df.withColumn(
    "trend",
    F.row_number().over(trend_window) - 1
)

# =========================
# LAG FEATURES
# =========================
# Price observed exactly 24 hours earlier.
#
# This is looked up by timestamp instead of `lag(..., 24)` over the per-day
# window: a window partition only holds one calendar day's rows, so for
# hourly data there is never a row 24 positions back and the lag came out
# NULL. Matching on `datetime - 24h` is also unaffected by missing hours.
lag_source = (
    df.select(
        (F.col("datetime") + F.expr("INTERVAL 24 HOURS")).alias("datetime"),
        F.col("price_nok").alias("price_lag_24"),
    )
    # Guarantee at most one lag value per timestamp so the left join below
    # can never multiply rows.
    .dropDuplicates(["datetime"])
)

df = (
    df
    # Mirror withColumn's "replace if present" behaviour.
    .drop("price_lag_24")
    .join(lag_source, on="datetime", how="left")
)

# =========================
# DROP WARM-UP ROWS (INCREMENTAL RUNS)
# =========================
# On incremental runs the oldest 24 hours of the lookback window only exist
# to feed the features above: those rows have no lag source inside the
# window and may belong to a partially-read day (wrong `trend`). Emitting
# them would overwrite correct values already stored in the output table,
# so they are removed here. LOOKBACK_HOURS should therefore be >= 24; the
# effective re-processing window is LOOKBACK_HOURS - 24 hours.
if max_dt is not None:
    emit_lookback_hours = max(LOOKBACK_HOURS - 24, 0)
    emit_start = F.lit(max_dt) - F.expr(f"INTERVAL {emit_lookback_hours} HOURS")
    df = df.filter(F.col("datetime") >= emit_start)

# =========================
# ONE-HOT ENCODING (HOUR, DAY OF WEEK)
# =========================
# Indicator columns for hour 1..23 and day_of_week 1..6. Level 0 of each
# feature gets no column of its own, so it is represented by all of that
# feature's indicators being 0.
#
# All indicators are added in one projection (`withColumns`) instead of one
# `withColumn` call per column inside a loop, which builds a needlessly deep
# query plan.
one_hot_columns = {
    **{
        f"hour_{h}": F.when(F.col("hour") == h, 1).otherwise(0)
        for h in range(1, 24)
    },
    **{
        f"day_of_week_{d}": F.when(F.col("day_of_week") == d, 1).otherwise(0)
        for d in range(1, 7)
    },
}

# The raw categorical columns are only needed to derive the indicators.
df = df.withColumns(one_hot_columns).drop("hour", "day_of_week")

# =========================
# FINAL COLUMN ORDER
# =========================
# Dict iteration follows insertion order: hour_1..hour_23 and then
# day_of_week_1..day_of_week_6.
final_updates_df = df.select(
    "datetime",
    "price_nok",
    "temperature",
    "trend",
    "price_lag_24",
    *one_hot_columns,
)

# =========================
# MERGE INTO SILVER
# =========================
if not spark.catalog.tableExists(output_table):
    # First run: there is no target yet, so the feature rows become the table.
    initial_writer = final_updates_df.write.format("delta").mode("overwrite")
    initial_writer.saveAsTable(output_table)
else:
    # Later runs: upsert keyed on `datetime` -- matching rows are fully
    # replaced by the freshly computed values, unseen timestamps are inserted.
    delta_out = DeltaTable.forName(spark, output_table)

    upsert = delta_out.alias("t").merge(
        final_updates_df.alias("s"),
        "t.datetime = s.datetime",
    )
    upsert.whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()


In [0]:
%sql
-- Preview the ten earliest rows of the feature table.
SELECT *
FROM `electricity-project`.silver.price_features
ORDER BY datetime
LIMIT 10;


In [0]:
%sql
-- Preview the ten most recent rows of the feature table.
SELECT *
FROM `electricity-project`.silver.price_features
ORDER BY datetime DESC
LIMIT 10;

In [0]:
%sql
-- Data-quality check: total row count and how many rows have no 24-hour lag value.
SELECT
  count(*),
  sum(CASE WHEN price_lag_24 IS NULL THEN 1 ELSE 0 END) AS lag_nulls
FROM `electricity-project`.silver.price_features;
