In [1]:
import sys
import os
import mlflow
import json
import pandas as pd
import numpy as np
from urllib.parse import urlparse

from prophet import Prophet, serialize
from prophet.diagnostics import cross_validation, performance_metrics

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

import logging

In [2]:
# Seed NumPy's legacy global RNG up front. NOTE(review): presumably so that any
# sampling done through the global RNG (e.g. uncertainty intervals) repeats
# across runs — confirm.
np.random.seed(12345)

# Prophet's example retail-sales CSV (monthly observations), read over HTTPS.
SOURCE_DATA = "https://raw.githubusercontent.com/facebook/prophet/master/examples/example_retail_sales.csv"

# Sub-path under the active MLflow run's artifact root where the model is logged.
ARTIFACT_PATH = "model"

In [3]:
def extract_params(pr_model):
    """Collect a Prophet model's simple attributes into a dict.

    Each attribute name listed in ``prophet.serialize.SIMPLE_ATTRIBUTES`` is
    read from ``pr_model`` with ``getattr``. The names are visited in list
    order, so the dict preserves that order.

    Args:
        pr_model: A Prophet model instance.

    Returns:
        dict: attribute name -> attribute value.

    Raises:
        AttributeError: if ``pr_model`` lacks one of the listed attributes.
    """
    collected = {}
    for attr_name in serialize.SIMPLE_ATTRIBUTES:
        collected[attr_name] = getattr(pr_model, attr_name)
    return collected

In [30]:
# Keep only the first N_ROWS rows of the CSV. NOTE(review): presumably a
# truncation to keep the cross-validation below quick — confirm intent.
N_ROWS = 100
sales_data = pd.read_csv(SOURCE_DATA).head(N_ROWS)
# Prophet fits on a `ds` (date) column and a `y` (target) column; fail fast
# here if the downloaded file does not provide them.
assert {"ds", "y"}.issubset(sales_data.columns), "expected 'ds' and 'y' columns"

In [None]:
with mlflow.start_run():
    # Fit a default-configured Prophet model on the sales history loaded above.
    model = Prophet().fit(sales_data)

    params = extract_params(model)

    # Rolling-origin cross-validation: train on the first `initial` span of the
    # history, forecast `horizon` ahead from each cutoff, then move the cutoff.
    # `period` (cutoff spacing) is left at Prophet's default, which the Prophet
    # diagnostics docs describe as half the horizon.
    metric_keys = ["mse", "rmse", "mae", "mape"]
    metrics_raw = cross_validation(
        model=model,
        horizon="365 days",
        initial="710 days",
        parallel="threads",
    )
    cv_metrics = performance_metrics(metrics_raw)
    # performance_metrics can omit "mape" (e.g. when y is close to zero), so
    # average only the requested metrics that are actually present instead of
    # raising KeyError. float() turns NumPy scalars into plain Python floats
    # for JSON output and MLflow.
    metrics = {
        k: float(cv_metrics[k].mean())
        for k in metric_keys
        if k in cv_metrics.columns
    }

    print(f"Logged Metrics: \n{json.dumps(metrics, indent=2)}")
    # default=str keeps the dump from failing on any attribute value that is
    # not natively JSON-serializable.
    print(f"Logged Params: \n{json.dumps(params, indent=2, default=str)}")

    mlflow.prophet.log_model(model, artifact_path=ARTIFACT_PATH)
    mlflow.log_params(params)
    mlflow.log_metrics(metrics)
    model_uri = mlflow.get_artifact_uri(ARTIFACT_PATH)
    print(f"Model artifact logged to: {model_uri}")

The git executable must be specified in one of the following ways:
    - be included in your $PATH
    - be set via $GIT_PYTHON_GIT_EXECUTABLE
    - explicitly set via git.refresh()

All git commands will error until this is rectified.

$GIT_PYTHON_REFRESH environment variable. Use one of the following values:
    - error|e|raise|r|2: for a raised exception

Example:
    export GIT_PYTHON_REFRESH=quiet

INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
INFO:prophet:Making 44 forecasts with cutoffs between 1994-02-21 00:00:00 and 2015-05-02 00:00:00
INFO:prophet:Applying in parallel with <concurrent.futures.thread.ThreadPoolExecutor object at 0x000001E9B14336A0>
INFO:prophet:n_changepoints greater than number of observations. Using 19.
INFO:prophet:n_changepoints greater than number of observations. Using 24.


In [6]:
loaded_model = mlflow.prophet.load_model(model_uri)

# make_future_dataframe steps daily by default (freq="D"), but the training
# dates here are month starts. Step at the history's own inferred frequency,
# and fall back to daily only when no regular frequency can be inferred.
# NOTE(review): with a monthly history, periods=60 now means 60 months ahead
# rather than 60 days — confirm the intended horizon.
future_freq = pd.infer_freq(loaded_model.history_dates) or "D"
forecast = loaded_model.predict(
    loaded_model.make_future_dataframe(periods=60, freq=future_freq)
)

# The frame also contains the in-sample history, so head() shows fitted values
# for the earliest dates rather than future forecasts.
print(f"forecast:\n{forecast.head(30)}")

forecast:
$           ds          trend     yhat_lower     yhat_upper    trend_lower  \
0  1992-01-01  162809.824099  118950.209105  138650.701855  162809.824099   
1  1992-02-01  163861.470293  123977.690665  144162.996529  163861.470293   
2  1992-03-01  164845.268346  159040.506315  180302.132918  164845.268346   
3  1992-04-01  165896.914540  152813.085351  172616.620895  165896.914540   
4  1992-05-01  166914.636663  169409.986353  189547.809171  166914.636663   
5  1992-06-01  167966.282857  160613.277618  181361.740627  167966.282857   
6  1992-07-01  168984.004981  162005.145425  182230.984217  168984.004981   
7  1992-08-01  170035.651175  168597.744698  188928.775884  170035.651175   
8  1992-09-01  171087.297369  148765.637361  169913.250614  171087.297369   
9  1992-10-01  172105.019492  159613.453504  180084.298847  172105.019492   
10 1992-11-01  173156.665985  162340.294550  182470.813110  173156.665985   
11 1992-12-01  174174.388398  208432.185711  228303.502804  17417