In [0]:
##################################################################################
# Prophet Forecasting Model Training Notebook
#
# This notebook trains a Prophet forecasting model and registers it to Unity Catalog.
#
# Parameters:
# * env (required)              - Environment the notebook is run in (dev, staging, or prod)
# * catalog (required)          - Catalog name for the training data
# * schema (required)           - Schema name for the training data  
# * table (required)            - Table name for the training data
# * forecast_horizon (required) - Number of periods (days; the model uses daily frequency) to forecast
# * experiment_name (required)  - MLflow experiment name for the training runs
# * model_name (required)       - Three-level name (<catalog>.<schema>.<model_name>) to register the trained model in Unity Catalog
# * serving_endpoint_name       - Optional name for the serving endpoint
##################################################################################

# MAGIC %load_ext autoreload
# MAGIC %autoreload 2

In [None]:
# DBTITLE 1, Install dependencies
# MAGIC %pip install prophet databricks-sdk mlflow grpcio grpcio-status pandas
dbutils.library.restartPython()  # restart the Python process after %pip install; NOTE(review): versions above are unpinned


In [0]:
# DBTITLE 1, Notebook arguments
# List of input args needed to run this notebook as a job.
# Provide them via DB widgets or notebook arguments.

# Notebook Environment
dbutils.widgets.dropdown("env", "dev", ["dev", "staging", "prod"], "Environment Name")
env = dbutils.widgets.get("env")

# Training data location
dbutils.widgets.text("catalog", "johannes_oehler", label="Data Catalog")
dbutils.widgets.text("schema", "vectorlab", label="Data Schema")
dbutils.widgets.text("table", "forecast_data", label="Data Table")

# Forecast configuration
dbutils.widgets.text("forecast_horizon", "10", label="Forecast Horizon")

# MLflow experiment name.
# The default is derived from the selected environment (e.g. /dev-..., /prod-...),
# so runs from different environments do not land in the same experiment.
dbutils.widgets.text(
    "experiment_name",
    f"/{env}-prophet-forecast-experiment",
    label="MLflow experiment name",
)

# Unity Catalog registered model name (three-level namespace)
dbutils.widgets.text(
    "model_name", 
    "johannes_oehler.vectorlab.prophet_forecast", 
    label="Full (Three-Level) Model Name"
)

# Optional: Serving endpoint name
dbutils.widgets.text(
    "serving_endpoint_name",
    "forecast_joe",
    label="Serving Endpoint Name"
)

In [None]:
# DBTITLE 1, Define input and output variables
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
table = dbutils.widgets.get("table")
forecast_horizon = int(dbutils.widgets.get("forecast_horizon"))
experiment_name = dbutils.widgets.get("experiment_name")
model_name = dbutils.widgets.get("model_name")
# NOTE(review): not used in this notebook; presumably consumed by a later deployment step.
serving_endpoint_name = dbutils.widgets.get("serving_endpoint_name")

# Fail fast on invalid arguments. Otherwise they would surface much later:
# a non-positive horizon breaks the LIMIT-based train/test split, and a
# malformed model name only fails at registration, after training has run.
if forecast_horizon <= 0:
    raise ValueError(f"forecast_horizon must be a positive integer, got {forecast_horizon}")

_model_name_parts = model_name.split(".")
if len(_model_name_parts) != 3 or not all(_model_name_parts):
    raise ValueError(
        f"model_name must be a three-level name '<catalog>.<schema>.<model>', got '{model_name}'"
    )


In [None]:
# DBTITLE 1, Set MLflow experiment and Unity Catalog registry
import mlflow
from mlflow.tracking import MlflowClient

# Runs started later in this notebook are logged under this experiment.
mlflow.set_experiment(experiment_name)
# Register models in Unity Catalog rather than the workspace model registry.
mlflow.set_registry_uri('databricks-uc')

# Echo the resolved configuration for the run log.
for _label, _value in (
    ("Experiment", experiment_name),
    ("Model will be registered to", model_name),
):
    print(f"{_label}: {_value}")


In [None]:
# DBTITLE 1, Helper function
def get_latest_model_version(model_name, client=None):
    """Return the highest registered version number of ``model_name``.

    Args:
        model_name: Name of the registered model (a three-level Unity Catalog
            name in this notebook).
        client: Optional MLflow client exposing ``search_model_versions``.
            A new ``MlflowClient()`` is created when omitted.

    Returns:
        int: The largest version number registered for the model.

    Raises:
        ValueError: If the model has no registered versions. Returning a
            made-up version here would produce a URI pointing at nothing.
    """
    mlflow_client = client if client is not None else MlflowClient()
    versions = [
        int(mv.version)  # MLflow reports versions as strings
        for mv in mlflow_client.search_model_versions(f"name='{model_name}'")
    ]
    if not versions:
        raise ValueError(f"No registered versions found for model '{model_name}'")
    return max(versions)


In [0]:
# DBTITLE 1, Load and split data
# Backtick-quote each identifier (doubling any embedded backtick) so widget
# values are always interpreted as names, never as SQL.
source_table = ".".join(
    "`" + part.replace("`", "``") + "`" for part in (catalog, schema, table)
)

# Daily sales of a single store (store 1) to keep the calculations simple.
# Filtering before aggregation is equivalent to filtering afterwards because
# `store` is a GROUP BY key, and it avoids aggregating every other store.
query = (
    f"SELECT date, store, SUM(sales) as sales FROM {source_table} "
    "WHERE store = 1 GROUP BY date, store ORDER BY date desc"
)
df = spark.sql(query)

# train-test-split: the most recent `forecast_horizon` rows become the test set.
# Count once, and validate so a too-short history fails with a clear message
# instead of a negative LIMIT error or a training set Prophet cannot fit
# (Prophet needs at least 2 rows).
n_rows = df.count()
if n_rows - forecast_horizon < 2:
    raise ValueError(
        f"{source_table} has only {n_rows} rows for store 1; need at least "
        f"forecast_horizon + 2 = {forecast_horizon + 2} for a train/test split."
    )
train_df = df.orderBy(df.date.asc()).limit(n_rows - forecast_horizon).orderBy(df.date.desc())
test_df = df.orderBy(df.date.desc()).limit(forecast_horizon).toPandas()

train_df.show(5)
test_df.head(5)

In [0]:
# DBTITLE 1, Clean data and remove outliers
from pyspark.sql.functions import col, lit

# Rows without a 'sales' value cannot be used for training, so drop them first.
cleaned_df = train_df.na.drop(subset=["sales"])
cleaned_df.show(5)

# Approximate first and third quartiles (5% relative error) and the IQR.
quartiles = cleaned_df.approxQuantile("sales", [0.25, 0.75], 0.05)
first_quartile, third_quartile = quartiles[0], quartiles[1]
IQR = third_quartile - first_quartile

# Keep rows whose sales fall in (lower_bound, upper_bound]. The upper fence is
# the usual Q3 + 1.5 * IQR. The lower fence is fixed at 0 instead of
# Q1 - 1.5 * IQR, so the strict '>' also removes zero and negative sales.
lower_bound = 0
upper_bound = third_quartile + 1.5 * IQR
sales_in_range = (col("sales") > lit(lower_bound)) & (col("sales") <= lit(upper_bound))
no_outliers_df = cleaned_df.where(sales_in_range)

# Preview the filtered DataFrame.
no_outliers_df.show(5)

In [0]:
# DBTITLE 1, Train Prophet model
from prophet import Prophet
from pyspark.sql.functions import col, to_date

# Prophet requires at the minimum 2 columns - ds (date) & y (value to forecast).
# Use a new name instead of overwriting `train_df`. The cleaning cell above
# reads `train_df.sales`, so clobbering it would break a re-run of that cell.
prophet_train_df = no_outliers_df.select(
    to_date(col("date")).alias("ds"),
    col("store"),
    col("sales").cast("double").alias("y"),
).orderBy(col("ds").desc())

# set model parameters
# NOTE(review): daily_seasonality models intra-day patterns. With one row per
# date (the source query groups by date), it is likely not identifiable;
# consider setting it to False.
prophet_model = Prophet(
  interval_width=0.95,
  growth='linear',
  daily_seasonality=True,
  weekly_seasonality=True,
  yearly_seasonality=True,
  seasonality_mode='additive'
  )

# fit the model to historical data
history_pd = prophet_train_df.toPandas()
prophet_model.fit(history_pd)

In [0]:
# MAGIC %md
# MAGIC Wrap the trained Prophet model for MLflow, then log it and register it to Unity Catalog.


In [0]:
# DBTITLE 1, Create model wrapper and signature
import mlflow
from mlflow.pyfunc import PythonModel
from mlflow.models.signature import infer_signature
import pandas as pd

# Wrapper class exposing the fitted Prophet model through the MLflow pyfunc
# interface. mlflow.pyfunc.log_model serializes this wrapper instance, including
# the Prophet model it holds, when the model is logged.
class ProphetWrapper(PythonModel):
    """MLflow pyfunc model producing a fixed-horizon daily Prophet forecast."""

    # Columns returned by predict(), in this order.
    _FORECAST_COLUMNS = ["ds", "yhat", "yhat_lower", "yhat_upper"]

    def __init__(self, model, forecast_horizon):
        """
        Store the fitted model and the forecast horizon.

        Args:
            model: Fitted Prophet model used to produce forecasts.
            forecast_horizon: Number of daily periods forecast beyond the
                model's training history.
        """
        self.model = model
        self.forecast_horizon = forecast_horizon

    def predict(self, context, model_input: pd.DataFrame) -> pd.DataFrame:
        """
        Forecast the training history plus `forecast_horizon` future days.

        Args:
            context: MLflow model context (unused).
            model_input: Accepted for pyfunc compatibility but NOT used. The
                forecast always spans the model's own training history plus
                the configured horizon, regardless of the rows passed in.

        Returns:
            DataFrame with columns ds, yhat, yhat_lower, yhat_upper.
        """
        future_dates = self.model.make_future_dataframe(
            periods=self.forecast_horizon, freq="d", include_history=True
        )
        full_forecast = self.model.predict(future_dates)
        return full_forecast[self._FORECAST_COLUMNS]

# Wrap the trained Prophet model. mlflow.pyfunc.log_model serializes this
# wrapper object when the model is logged.
wrapped_model = ProphetWrapper(prophet_model, forecast_horizon)

# Create input example and signature for MLflow.
# The output schema must describe what the logged pyfunc actually returns
# (ds, yhat, yhat_lower, yhat_upper). The raw prophet_model.predict output
# has many more columns (trend, seasonal components, ...), which would make
# the signature advertise columns the served model never produces.
input_example = history_pd.head()[["ds", "y"]]
signature = infer_signature(input_example, wrapped_model.predict(None, input_example))

In [0]:
# DBTITLE 1, Train and register model
from importlib.metadata import version as _installed_version

# Pin the serving environment to the exact package versions used for training.
# The logged wrapper (and the fitted Prophet model with its pandas history) is
# serialized at log time, and deserializing it under different prophet/pandas/
# numpy/cloudpickle versions is not guaranteed to work. A fixed pin such as
# "pandas<2.0.0" can contradict the unpinned `%pip install pandas` cell.
model_pip_requirements = [
    f"{package}=={_installed_version(package)}"
    for package in ("prophet", "pandas", "numpy", "cloudpickle")
]

# Prophet constructor arguments to record. The values are read back from the
# fitted model, so the logged params always match what was actually trained.
PROPHET_PARAM_NAMES = (
    "interval_width",
    "growth",
    "daily_seasonality",
    "weekly_seasonality",
    "yearly_seasonality",
    "seasonality_mode",
)

# Start MLflow run and log the model with Unity Catalog registration
with mlflow.start_run(run_name="prophet_training") as run:
    # Log model parameters
    mlflow.log_param("forecast_horizon", forecast_horizon)
    for param_name in PROPHET_PARAM_NAMES:
        mlflow.log_param(param_name, getattr(prophet_model, param_name))

    # Log training data info
    mlflow.log_param("training_data_source", f"{catalog}.{schema}.{table}")
    mlflow.log_param("training_samples", len(history_pd))

    # Log the model and register it to Unity Catalog
    mlflow.pyfunc.log_model(
        artifact_path="prophet_model",
        python_model=wrapped_model,
        signature=signature,
        input_example=input_example,
        registered_model_name=model_name,  # This registers the model to Unity Catalog
        pip_requirements=model_pip_requirements,
    )

    run_id = run.info.run_id
    print(f"Training run logged with run_id: {run_id}")
    print(f"Model registered to Unity Catalog as: {model_name}")


In [None]:
# DBTITLE 1, Return model info for downstream tasks
# Resolve the newest registered version (assumed to be the one the training
# cell just created) and build its models:/ URI.
model_version = get_latest_model_version(model_name)
model_uri = f"models:/{model_name}/{model_version}"

# Publish the model coordinates as job task values so downstream tasks
# (e.g. validation, deployment) can read them.
task_values = {
    "model_uri": model_uri,
    "model_name": model_name,
    "model_version": model_version,
}
for key, value in task_values.items():
    dbutils.jobs.taskValues.set(key, value)

print(f"Model URI: {model_uri}")
print(f"Model Version: {model_version}")

# End the notebook, returning the model URI as its exit value for callers.
dbutils.notebook.exit(model_uri)
