In [0]:
from pyspark.sql import functions as F
from delta.tables import DeltaTable
from datetime import timedelta

# =========================
# CONFIG
# =========================
catalog_name  = "electricity-project"
silver_schema = "silver"
gold_schema   = "gold"

# Schema-qualified table names, resolved inside `catalog_name` once USE CATALOG
# runs below. They are built from the schema constants so that renaming a schema
# updates every table living in it (previously "gold." was hard-coded even though
# `gold_schema` drove CREATE/USE SCHEMA).
features_table = f"{silver_schema}.price_features"
params_table   = f"{gold_schema}.price_model_parameters"
output_table   = f"{gold_schema}.day_ahead_price_forecast"

LOOKBACK_HOURS = 48  # length of the forecast window ending at the latest feature timestamp
PREDICT_HOURS  = 24  # trailing hours of that window that receive predictions

# =========================
# CATALOG + SCHEMA
# =========================
# The catalog name contains a hyphen, so identifiers must be backtick-quoted;
# schema names are quoted the same way for consistency.
spark.sql(f"USE CATALOG `{catalog_name}`")
spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{gold_schema}`")
spark.sql(f"USE SCHEMA `{gold_schema}`")

# =========================
# LOAD LATEST MODEL
# =========================
# Pick the newest model version and pull its intercept and per-feature coefficients.
params_df = spark.table(params_table)

# NOTE(review): "newest" means MAX(model_version). If model_version is a string,
# this is a lexicographic max (e.g. "v10" < "v9"); confirm the column type.
latest_version = params_df.agg(F.max("model_version")).first()[0]
if latest_version is None:
    raise ValueError("No model versions found")

model_params = params_df.where(F.col("model_version") == latest_version)

# The intercept is read from a single parameter row. Presumably it is repeated
# on every row of the version; verify against the training job.
intercept = float(model_params.select("intercept").first()[0])

# feature_name -> coefficient. An "intercept" entry, if present, is dropped
# because the intercept is added separately when the model is applied.
coef_dict = {
    name: float(coef)
    for name, coef in model_params.select("feature_name", "coefficient").collect()
}
coef_dict.pop("intercept", None)

# Feature columns the forecast frame must provide, in coefficient order.
model_features = [*coef_dict]

# =========================
# LOAD FEATURES
# =========================
df = spark.table(features_table)

# Latest timestamp in the feature table; the forecast window is anchored on it.
max_dt = df.agg(F.max("datetime")).first()[0]
if max_dt is None:
    raise ValueError("Feature table empty")

# =========================
# BUILD LAGS (EXPLICIT, MATCH TRAINING)
# =========================
# The 24h lags are rebuilt by self-joining the FULL feature table on
#   b.datetime = l.dt_src + 24 hours
# (left join, so rows with no source row 24h earlier get NULL lags).
# Lag columns already stored in the feature table are discarded.
REBUILT_LAG_COLUMNS = ("price_lag_24", "temperature_lag_24")

base = df.alias("b")

lag_src = (
    df.select(
        F.col("datetime").alias("dt_src"),
        F.col("price_nok").alias("price_lag_24"),
        F.col("temperature").alias("temperature_lag_24"),
    )
    .alias("l")
)

joined = (
    base.join(
        lag_src,
        F.col("b.datetime") == F.col("l.dt_src") + F.expr("INTERVAL 24 HOURS"),
        "left",
    )
    .select(
        # core timeline
        F.col("b.datetime").alias("datetime"),
        # rebuilt lags (AUTHORITATIVE)
        *[F.col(f"l.{c}") for c in REBUILT_LAG_COLUMNS],
        # Everything else from base. `datetime` is excluded because it was
        # already selected above; keeping it would produce two `datetime`
        # columns and make every unqualified reference ambiguous.
        *[
            F.col(f"b.{c}")
            for c in base.columns
            if c != "datetime" and c not in REBUILT_LAG_COLUMNS
        ],
    )
)

# =========================
# FORECAST WINDOW
# The last LOOKBACK_HOURS up to max_dt; predictions are made for the last
# PREDICT_HOURS of them. Lags come from the full-table join above, so they
# do not depend on the earlier part of this window.
# =========================
window_start = max_dt - timedelta(hours=LOOKBACK_HOURS - 1)
predict_start = max_dt - timedelta(hours=PREDICT_HOURS - 1)

forecast_df = joined.filter(
    (F.col("datetime") >= F.lit(window_start))
    & (F.col("datetime") <= F.lit(max_dt))
)

forecast_for_date = F.to_date(F.lit(max_dt))

# =========================
# ENSURE MODEL FEATURES EXIST
# =========================
# Missing one-hot dummy columns are filled with 0.0; any other missing
# model feature is a hard error.
dummy_prefixes = ("hour_", "day_of_week_")

available_columns = set(forecast_df.columns)
for feature in model_features:
    if feature in available_columns:
        continue
    if not feature.startswith(dummy_prefixes):
        raise ValueError(f"Missing required feature: {feature}")
    forecast_df = forecast_df.withColumn(feature, F.lit(0.0))

# =========================
# STRICT NON-NULL CHECK (prediction rows only)
# =========================
# A NULL in any non-dummy feature would make predicted_price NULL (NULL * coef
# propagates through the sum), so all of them are required. Dummies are
# coalesced to 0.0 when the model is applied, so they are exempt.
required_non_null = [
    c for c in model_features if not c.startswith(dummy_prefixes)
]

# Check the rows that will actually be scored, not the whole lookback window;
# otherwise the check can pass while every prediction row is dropped.
predict_candidates = forecast_df.filter(F.col("datetime") >= F.lit(predict_start))

rows_before = predict_candidates.count()
predict_df = predict_candidates.dropna(subset=required_non_null)
rows_after = predict_df.count()

if rows_after == 0:
    # Diagnose on the frame BEFORE dropna: afterwards it is empty and every
    # null count would trivially be 0.
    predict_candidates.select(
        F.count("*").alias("rows"),
        *[
            F.sum(F.when(F.col(c).isNull(), 1).otherwise(0)).alias(f"null_{c}")
            for c in required_non_null
        ],
    ).show(truncate=False)

    raise ValueError(
        f"Forecast window unusable. Rows before={rows_before}, after={rows_after}"
    )

if rows_after < rows_before:
    print(
        f"Dropped {rows_before - rows_after} of {rows_before} prediction rows "
        f"with NULL required features: {required_non_null}"
    )

# Duplicate datetimes would mean several predictions for the same timestamp,
# and a Delta MERGE rejects multiple source rows matching one target row.
distinct_datetimes = predict_df.select("datetime").distinct().count()
if distinct_datetimes != rows_after:
    raise ValueError(
        f"Duplicate datetimes in forecast window: {rows_after} rows "
        f"for {distinct_datetimes} distinct datetimes"
    )

# =========================
# APPLY LINEAR MODEL
# =========================
# predicted_price = intercept + sum(coef_i * feature_i); NULL dummies count as 0.
prediction_expr = F.lit(intercept)
for feature, coef in coef_dict.items():
    term = F.col(feature)
    if feature.startswith(dummy_prefixes):
        term = F.coalesce(term, F.lit(0.0))
    prediction_expr = prediction_expr + term * F.lit(coef)

result_df = (
    predict_df
    .withColumn("predicted_price", prediction_expr)
    .withColumn("model_version", F.lit(latest_version))
    .withColumn("forecast_for_date", forecast_for_date)
    .withColumn("generated_at", F.current_timestamp())
    .select(
        "datetime",
        "predicted_price",
        "model_version",
        "forecast_for_date",
        "generated_at",
    )
    .orderBy("datetime")
)

# =========================
# WRITE FORECAST
# =========================
# Upsert keyed on (datetime, model_version, forecast_for_date). Re-running for
# the same key replaces the stored row with the new one (including generated_at)
# instead of inserting a second copy.
if not spark.catalog.tableExists(output_table):
    # First run: create the Delta table directly from this batch.
    result_df.write.format("delta").mode("overwrite").saveAsTable(output_table)
else:
    merge_condition = """
            t.datetime = s.datetime
            AND t.model_version = s.model_version
            AND t.forecast_for_date = s.forecast_for_date
            """
    (
        DeltaTable.forName(spark, output_table)
        .alias("t")
        .merge(result_df.alias("s"), merge_condition)
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

print("Forecast written for model_version:", latest_version)


In [0]:
# %sql
# DROP TABLE `electricity-project`.gold.day_ahead_price_forecast




In [0]:
 #%sql
 #DROP TABLE IF EXISTS `electricity-project`.gold.day_ahead_price_forecast;


In [0]:
%sql
-- Time coverage and size of the observed-weather bronze table.
SELECT MIN(datetime) AS min_dt,
       MAX(datetime) AS max_dt,
       COUNT(*)      AS n_rows
  FROM `electricity-project`.bronze.weather_observed;


In [0]:
%sql
-- NOTE(review): identical to the previous cell's query on bronze.weather_observed;
-- confirm whether a different table was meant to be checked here.
select
  min(datetime) as min_dt,
  max(datetime) as max_dt,
  count(*)      as n_rows
from `electricity-project`.bronze.weather_observed;


In [0]:
%sql
-- Integrity check: datetimes that occur more than once in silver.price_features.
-- An empty result means datetime is unique.
SELECT datetime,
       COUNT(*)
  FROM `electricity-project`.silver.price_features
 GROUP BY datetime
HAVING COUNT(*) > 1;


In [0]:
%sql
-- Non-NULL coverage of the stored lag and temperature columns:
-- COUNT(col) counts only non-NULL values, COUNT(*) counts every row.
SELECT COUNT(*)            AS total_rows,
       COUNT(price_lag_24) AS lag_available,
       COUNT(temperature)  AS temp_available
  FROM `electricity-project`.silver.price_features;


In [0]:
%sql
-- First 10 feature rows in time order, including the sample dummy columns
-- hour_12 and day_of_week_3 as a spot check.
SELECT datetime, price_nok, price_lag_24, temperature, hour_12, day_of_week_3
  FROM `electricity-project`.silver.price_features
 ORDER BY datetime
 LIMIT 10;


In [0]:
%sql
-- Date range and size of the model training data table.
SELECT MIN(datetime) AS train_start,
       MAX(datetime) AS train_end,
       COUNT(*)      AS n_rows
  FROM `electricity-project`.gold.price_model_training_data;


In [0]:
%sql
-- One row per model version: number of parameter rows and earliest trained_at.
-- NOTE(review): n_features counts every parameter row, which may include an
-- "intercept" row; confirm before reading it as a feature count.
SELECT model_version,
       COUNT(*)        AS n_features,
       MIN(trained_at) AS trained_at
  FROM `electricity-project`.gold.price_model_parameters
 GROUP BY model_version
 ORDER BY model_version;


In [0]:
%sql
-- Earliest 10 datetimes with both an actual price and a forecast, plus absolute error.
-- The forecast table is upserted on (datetime, model_version, forecast_for_date), so
-- one datetime can carry several forecasts. Restrict to the latest model version
-- and show the key columns so any remaining duplicates are visible rather than
-- silently mixed.
SELECT
  a.datetime,
  f.model_version,
  f.forecast_for_date,
  a.price_nok AS actual,
  f.predicted_price,
  ABS(a.price_nok - f.predicted_price) AS abs_error
FROM `electricity-project`.gold.actual_prices a
JOIN `electricity-project`.gold.day_ahead_price_forecast f
  USING (datetime)
WHERE f.model_version = (
  SELECT MAX(model_version)
  FROM `electricity-project`.gold.day_ahead_price_forecast
)
ORDER BY a.datetime ASC, f.forecast_for_date DESC
LIMIT 10;


In [0]:
%sql
SELECT MAX(datetime)
FROM `electricity-project`.gold.actual_prices;


In [0]:
%sql
SELECT MIN(datetime), MAX(datetime), COUNT(*)
FROM `electricity-project`.gold.day_ahead_price_forecast;


In [0]:
%sql
SELECT *
FROM `electricity-project`.gold.day_ahead_price_forecast;