In [7]:
#Importando as bibliotecas necessárias
import sys
import os
import mlflow
from urllib.parse import urlparse

#Manipulação de dados
import json
import pandas as pd
import numpy as np

#Ignorar avisos de atualização, etc
import warnings
warnings.filterwarnings("ignore")

from prophet import Prophet, serialize
from prophet.diagnostics import cross_validation, performance_metrics

#Gráficos
import seaborn as sns

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

import logging

In [2]:
logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)  # só para pegar a sessao e quem fez o log

ARTIFACT_PATH = "model"
mlflow.set_tracking_uri('http://localhost:5000')  # porta do mlflow
# #mlflow.set_tracking_uri('mysql://root:root@localhost:3306/mlflow')  # caso fosse um banco mysql por exemplo
mlflow.set_experiment(experiment_name='test_mlflow')  # nome do experimento

# nome das tags
tags = {
        "Projeto": "MLflow Teste",
        "ml_model": "Prophet"
       }

In [3]:
def extract_params(pr_model):
    return {attr: getattr(pr_model, attr) for attr in serialize.SIMPLE_ATTRIBUTES}

In [5]:
SOURCE_DATA = ("https://raw.githubusercontent.com/facebook/prophet/master/examples/example_retail_sales.csv")
sales_data = pd.read_csv(SOURCE_DATA)

In [8]:
with mlflow.start_run(run_name='prophet_v1'):

    model = Prophet().fit(sales_data)
    params = extract_params(model)

    metric_keys = ["mse", "rmse", "mae", "mape"]
    metrics_raw = cross_validation(
        model=model,
        initial="1825 days",  # janela de tempo que usaremos para train
        horizon="365 days",  # o quanto vamos prever para frente
#         period="180 days",  # o quanto vamos pular depois do primeiro treino (se nao passarmos esse param ele faz 50% do horizon)
        parallel="threads",
#         disable_tqdm=True,
    )
    cv_metrics = performance_metrics(metrics_raw)
    metrics = {k: cv_metrics[k].mean() for k in metric_keys}

    print(f"Logged Metrics: \n{json.dumps(metrics, indent=2)}")
    print(f"Logged Params: \n{json.dumps(params, indent=2)}")

    mlflow.prophet.log_model(model, artifact_path=ARTIFACT_PATH)
    mlflow.log_params(params)
    mlflow.log_metrics(metrics)
    model_uri = mlflow.get_artifact_uri(ARTIFACT_PATH)
    print(f"Model artifact logged to: {model_uri}")

INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
INFO:prophet:Making 37 forecasts with cutoffs between 1997-05-06 00:00:00 and 2015-05-02 00:00:00
INFO:prophet:Applying in parallel with <concurrent.futures.thread.ThreadPoolExecutor object at 0x000002359E26DF10>


Logged Metrics: 
{
  "mse": 393213342.5809708,
  "rmse": 19506.712984496884,
  "mae": 14287.717121383135,
  "mape": 0.041057912754979824
}
Logged Params: 
{
  "growth": "linear",
  "n_changepoints": 25,
  "specified_changepoints": false,
  "changepoint_range": 0.8,
  "yearly_seasonality": "auto",
  "weekly_seasonality": "auto",
  "daily_seasonality": "auto",
  "seasonality_mode": "additive",
  "seasonality_prior_scale": 10.0,
  "changepoint_prior_scale": 0.05,
  "holidays_prior_scale": 10.0,
  "mcmc_samples": 0,
  "interval_width": 0.8,
  "uncertainty_samples": 1000,
  "y_scale": 518253.0,
  "logistic_floor": false,
  "country_holidays": null,
  "component_modes": {
    "additive": [
      "yearly",
      "additive_terms",
      "extra_regressors_additive",
      "holidays"
    ],
    "multiplicative": [
      "multiplicative_terms",
      "extra_regressors_multiplicative"
    ]
  }
}
Model artifact logged to: ./artifacts/1/bdca9cddd899437e95ba68810b046065/artifacts/model


---

## Carregando o Modelo de Produção e Fazendo Previsões

In [None]:
import mlflow
import pandas as pd

loaded_model = mlflow.prophet.load_model(model_uri)

forecast = loaded_model.predict(loaded_model.make_future_dataframe(60))

print(f"forecast:\n${forecast.head(30)}")

In [1]:
import mlflow
import pandas as pd

mlflow.set_tracking_uri('http://localhost:5000')

# Carregando o modelo que está em produção
# live = nome do modelo registrado
# Production é o stado dele no momento (poderia ser Staging por exemplo)
logged_model = 'models:/prophet_test/Production'
loaded_model = mlflow.prophet.load_model(logged_model)

In [6]:
sales_data

Unnamed: 0,ds,y
0,1992-01-01,146376
1,1992-02-01,147079
2,1992-03-01,159336
3,1992-04-01,163669
4,1992-05-01,170068
...,...,...
288,2016-01-01,400928
289,2016-02-01,413554
290,2016-03-01,460093
291,2016-04-01,450935


In [10]:
forecast = loaded_model.predict(loaded_model.make_future_dataframe(10, freq='M'))

forecast.tail(20)

Unnamed: 0,ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
283,2015-08-01,450161.601156,449205.693081,470482.825565,450161.601156,450161.601156,10035.337568,10035.337568,10035.337568,10035.337568,10035.337568,10035.337568,0.0,0.0,0.0,460196.938724
284,2015-09-01,451596.373349,427421.373515,447882.669862,451596.373349,451596.373349,-13502.548043,-13502.548043,-13502.548043,-13502.548043,-13502.548043,-13502.548043,0.0,0.0,0.0,438093.825307
285,2015-10-01,452984.862569,437213.526935,457198.686679,452984.862569,452984.862569,-5844.905559,-5844.905559,-5844.905559,-5844.905559,-5844.905559,-5844.905559,0.0,0.0,0.0,447139.95701
286,2015-11-01,454419.634763,441445.865104,462815.483914,454419.634763,454419.634763,-2536.217361,-2536.217361,-2536.217361,-2536.217361,-2536.217361,-2536.217361,0.0,0.0,0.0,451883.417402
287,2015-12-01,455808.123982,496206.985436,517602.01579,455808.123982,455808.123982,50868.360738,50868.360738,50868.360738,50868.360738,50868.360738,50868.360738,0.0,0.0,0.0,506676.48472
288,2016-01-01,457242.896176,413193.622369,433412.468376,457242.896176,457242.896176,-34207.792628,-34207.792628,-34207.792628,-34207.792628,-34207.792628,-34207.792628,0.0,0.0,0.0,423035.103548
289,2016-02-01,458677.66837,417919.869091,439064.486341,458677.66837,458677.66837,-30543.393274,-30543.393274,-30543.393274,-30543.393274,-30543.393274,-30543.393274,0.0,0.0,0.0,428134.275096
290,2016-03-01,460019.874615,454437.308815,474882.396322,460019.874615,460019.874615,4409.69771,4409.69771,4409.69771,4409.69771,4409.69771,4409.69771,0.0,0.0,0.0,464429.572325
291,2016-04-01,461454.646809,448334.175289,468658.068817,461454.646809,461454.646809,-3180.502879,-3180.502879,-3180.502879,-3180.502879,-3180.502879,-3180.502879,0.0,0.0,0.0,458274.14393
292,2016-05-01,462843.136029,465051.109635,485739.491373,462843.136029,462843.136029,12324.644434,12324.644434,12324.644434,12324.644434,12324.644434,12324.644434,0.0,0.0,0.0,475167.780463
