In [0]:
from databricks.feature_engineering import FeatureEngineeringClient
from pyspark.sql.functions import max, col, lit, expr
import datetime as dt
import mlflow
# Feature Engineering client; used further down to read the feature table.
fe = FeatureEngineeringClient()

In [0]:
%run ../config/variables

## Load data and declare variables

In [0]:
# Inference window: only feature rows whose `event_date` falls within the last
# `n` days (relative to the driver's current date) are loaded.
n = 10
today = dt.date.today()
initial_date = today - dt.timedelta(days=n)

# Feature table restricted to the inference window.
df_dataset = (
    fe.read_table(name=f'{catalog_name}.{gold_schema_name}.features_demand_forecast')
    .where(col('event_date') >= initial_date)
)

# Spark model resolved through the `prod` alias of the registered model.
model = mlflow.spark.load_model(
    f'models:/{catalog_name}.{gold_schema_name}.demand_forecast_model@prod'
)

## Preprocessing data for inference

In [0]:
# Most recent timestamp available for each district in the loaded window.
# NOTE: `max` here is pyspark.sql.functions.max (imported above), not the builtin.
latest_event_per_district = (
    df_dataset
    .groupBy('district')
    .agg(max('event_datetime').alias('event_datetime'))
)

# Lag features slide back by one slot: (column to overwrite, column to copy from).
# The order matters: every column is copied into the next-older slot *before*
# it is overwritten itself.
lag_shift_pairs = [
    ('prev_quantity_products_6', 'prev_quantity_products_5'),
    ('prev_quantity_products_5', 'prev_quantity_products_4'),
    ('prev_quantity_products_4', 'prev_quantity_products_3'),
    ('prev_quantity_products_3', 'prev_quantity_products_2'),
    ('prev_quantity_products_2', 'prev_quantity_products'),
    ('prev_quantity_products', 'sum_quantity_products'),
]

# Take the latest feature row of each district and move it one hour forward,
# recomputing the calendar columns derived from the new timestamp.
df_to_predict = (
    latest_event_per_district
    .join(df_dataset, on=['district', 'event_datetime'], how='left')
    .withColumn('event_datetime', expr("event_datetime + INTERVAL 1 HOUR"))
    .withColumn('event_date', expr("date(event_datetime)"))
    .withColumn('event_hour', expr("hour(event_datetime)"))
    .withColumn('event_weekday', expr("dayofweek(event_datetime)"))
)
for shifted_col, origin_col in lag_shift_pairs:
    df_to_predict = df_to_predict.withColumn(shifted_col, col(origin_col))

# The observed quantity has been moved into the first lag slot; the column
# itself is removed from the rows that go to the model.
df_to_predict = df_to_predict.drop('sum_quantity_products')

## Model inference

In [0]:
# Score the prepared rows with the loaded model. The following cell reads the
# model output from the `prediction` column of this DataFrame.
predictions = model.transform(df_to_predict)

In [0]:
# Expose the model output under the target column name and strip the
# intermediate columns produced by the model pipeline.
#
# `withColumnRenamed` and `drop` are both no-ops for columns that are not
# present, so this cell can be re-run on its own. The previous
# `withColumn('sum_quantity_products', col('prediction'))` failed on a second
# run because `prediction` had already been dropped by the first run.
predictions = (
    predictions
    .withColumnRenamed('prediction', 'sum_quantity_products')
    .drop(
        'district_index',
        'district_vec',
        'event_weekday_vec',
        'event_hour_vec',
        'features1',
        'features2',
        'features3',
        'features4',
        'features5',
        'features6',
    )
)

In [0]:
# Columns written to the forecast table.
output_columns = ['district', 'event_datetime', 'sum_quantity_products']

# Stack the observed history and the new predictions.
# Both sides are projected to the output columns *before* the union:
# `unionByName` requires both inputs to have the same set of column names, so
# unioning the full frames would fail as soon as the model pipeline emits an
# intermediate column that is not in the hard-coded drop list.
final_df = (
    df_dataset.select(*output_columns)
    .unionByName(predictions.select(*output_columns))
)

# Temp view read by the MERGE statement in the next section.
final_df.createOrReplaceTempView('final_df')

## Save predictions

In [0]:
# Upsert the `final_df` temp view into the forecast table, keyed on
# (district, event_datetime): matching rows get their quantity overwritten,
# new keys are inserted.
merge_statement = f"""
    MERGE INTO {catalog_name}.{gold_schema_name}.demand_forecast AS target
        USING final_df AS source
        ON target.district = source.district AND target.event_datetime = source.event_datetime
        WHEN MATCHED THEN
        UPDATE SET target.sum_quantity_products = source.sum_quantity_products
        WHEN NOT MATCHED THEN
        INSERT (district, event_datetime, sum_quantity_products)
        VALUES (source.district, source.event_datetime, source.sum_quantity_products)
          """
spark.sql(merge_statement)