In [1]:
from typing import Optional, Tuple

import pandas

from prophet import Prophet
from prophet.serialize import model_to_json, model_from_json

from sklearn.metrics import mean_squared_error

  from .autonotebook import tqdm as notebook_tqdm
Importing plotly failed. Interactive plots will not work.


In [2]:
# Country code that selects which processed workbook is loaded.
country = 'usa'

# Read the processed workbook for the selected country (path is relative to
# the notebook's working directory).
df = pandas.read_excel(f'../../../data/processed/{country}.xlsx')

In [3]:
def make_dataset(df_processed: pandas.DataFrame, df_covid_measures: Optional[pandas.DataFrame] = None) -> Tuple[pandas.DataFrame, pandas.DataFrame]:
    """Build the Prophet input frame from the processed country data.

    Selects the ``Time`` and ``Unemployment_Rate_TOT`` columns, renames them to
    ``ds`` and ``y``, and removes every row whose ``y`` value is missing.

    Args:
        df_processed: Frame containing at least the ``Time`` and
            ``Unemployment_Rate_TOT`` columns.
        df_covid_measures: Optional frame that is passed through unchanged as
            the second element of the result. When omitted, a new empty
            DataFrame is created for each call.

    Returns:
        A tuple ``(df_mrd, df_covid_measures)`` where ``df_mrd`` has the
        columns ``ds`` and ``y`` and keeps the original index labels of the
        rows that were retained.

    Raises:
        KeyError: If ``Time`` or ``Unemployment_Rate_TOT`` is not a column of
            ``df_processed``.
    """
    # A DataFrame default in the signature would be created once and shared by
    # every call; build a fresh one per call instead.
    if df_covid_measures is None:
        df_covid_measures = pandas.DataFrame()

    df_mrd = (
        df_processed[['Time', 'Unemployment_Rate_TOT']]
        .rename(columns = {'Time': 'ds', 'Unemployment_Rate_TOT': 'y'})
        # Row-wise null filtering: dropping by index label would also remove
        # valid rows that share a label with a null row.
        .dropna(subset = ['y'])
    )
    return df_mrd, df_covid_measures

In [4]:
# Build the Prophet-ready frame; the second return value is not used here.
df_mrd, _ = make_dataset(df)

In [5]:
def train_test_split(df_mrd: pandas.DataFrame, test_size: int = 12) -> Tuple[pandas.DataFrame, pandas.DataFrame]:
    """Split a frame into a leading training part and a trailing test part.

    The split is positional: the last ``test_size`` rows form the test set and
    all rows before them form the training set. Row order is preserved.

    Args:
        df_mrd: Frame to split.
        test_size: Number of trailing rows to hold out. If it exceeds the
            number of rows, every row goes to the test set and the training
            set is empty.

    Returns:
        A tuple ``(df_train, df_test)`` of independent copies.

    Raises:
        ValueError: If ``test_size`` is negative.
    """
    if test_size < 0:
        raise ValueError(f'test_size must be non-negative, got {test_size}')

    # Split by position rather than by index label so that repeated labels
    # cannot remove additional rows from the training set.
    split_at = max(len(df_mrd) - test_size, 0)
    df_train = df_mrd.iloc[:split_at].copy()
    df_test = df_mrd.iloc[split_at:].copy()
    return df_train, df_test

In [6]:
# Hold out the last 12 rows of df_mrd as the test set.
df_train, df_test = train_test_split(df_mrd, 12)

In [7]:
# Candidate hyper-parameter grid, kept for reference only: it is commented out
# and not used, so the model below is fitted with Prophet's default settings.
# param_grid = {  
#     'changepoint_prior_scale': [0.001, 0.01, 0.1, 0.5],
#     'seasonality_prior_scale': [0.01, 0.1, 1.0, 10.0],
# }

# Fit the baseline model on the training split.
model = Prophet().fit(df_train)

03:02:23 - cmdstanpy - INFO - Chain [1] start processing
03:02:23 - cmdstanpy - INFO - Chain [1] done processing


In [8]:
def test_model(df_test: pandas.DataFrame, model: Prophet) -> Tuple[pandas.DataFrame, float]:
    """Predict on the hold-out frame and score the forecast with RMSE.

    Args:
        df_test: Frame passed to ``model.predict``; its ``y`` column holds the
            observed values used for scoring.
        model: Fitted model exposing ``predict``.

    Returns:
        A tuple ``(df_predicted, rmse)``: the frame returned by
        ``model.predict`` and the root mean squared error between
        ``df_test['y']`` and ``df_predicted['yhat']``.
    """
    df_predicted: pandas.DataFrame = model.predict(df_test)
    # NOTE(review): y and yhat are compared by position, so this presumably
    # relies on df_test being in the same row order as the prediction output
    # (e.g. sorted by ds) -- confirm.
    # The `squared=False` argument was deprecated in scikit-learn 1.4 and
    # removed in 1.6; taking the square root of the plain MSE works on every
    # version.
    mse = mean_squared_error(y_true = df_test['y'], y_pred = df_predicted['yhat'])
    rmse: float = float(mse) ** 0.5
    return df_predicted, rmse

In [9]:
# Score the fitted model on the held-out rows.
df_predicted, rmse = test_model(df_test, model)

In [10]:
# Display the forecast frame for the held-out rows.
df_predicted

Unnamed: 0,ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
0,2022-04-01,5.574403,3.977121,7.539673,5.574403,5.574403,0.148965,0.148965,0.148965,0.148965,0.148965,0.148965,0.0,0.0,0.0,5.723369
1,2022-05-01,5.570483,3.948236,7.408052,5.570483,5.570483,0.122704,0.122704,0.122704,0.122704,0.122704,0.122704,0.0,0.0,0.0,5.693187
2,2022-06-01,5.566431,3.858294,7.446056,5.566431,5.566431,0.11115,0.11115,0.11115,0.11115,0.11115,0.11115,0.0,0.0,0.0,5.677581
3,2022-07-01,5.562511,3.935232,7.361435,5.562511,5.562511,0.071347,0.071347,0.071347,0.071347,0.071347,0.071347,0.0,0.0,0.0,5.633858
4,2022-08-01,5.558459,3.850108,7.45677,5.558459,5.558459,0.040512,0.040512,0.040512,0.040512,0.040512,0.040512,0.0,0.0,0.0,5.598972
5,2022-09-01,5.554408,3.811004,7.46081,5.554408,5.554408,0.020362,0.020362,0.020362,0.020362,0.020362,0.020362,0.0,0.0,0.0,5.57477
6,2022-10-01,5.550487,3.76537,7.415964,5.549899,5.550487,-0.002385,-0.002385,-0.002385,-0.002385,-0.002385,-0.002385,0.0,0.0,0.0,5.548102
7,2022-11-01,5.546436,3.819486,7.228677,5.544991,5.546436,0.007576,0.007576,0.007576,0.007576,0.007576,0.007576,0.0,0.0,0.0,5.554012
8,2022-12-01,5.542515,3.762848,7.179724,5.539773,5.542519,-0.015939,-0.015939,-0.015939,-0.015939,-0.015939,-0.015939,0.0,0.0,0.0,5.526577
9,2023-01-01,5.538464,3.768474,7.362314,5.53453,5.539325,-0.031286,-0.031286,-0.031286,-0.031286,-0.031286,-0.031286,0.0,0.0,0.0,5.507177


In [11]:
# Display the hold-out RMSE.
rmse

2.0350413964963825

In [12]:
# Serialize the fitted model to JSON and write it to a file named after the
# country, relative to the current working directory.
with open(f'{country}_prophet_base_model.json', 'w') as f:
    f.write(model_to_json(model))

In [13]:
# Dates to forecast, supplied in the 'ds' column that the model's predict call reads.
df_future = pandas.DataFrame(
    {
        'ds': [
            '2023-03-01',
            '2023-04-01',
            '2023-05-01',
        ],
    },
)

In [14]:
# Forecast the dates listed in df_future with the fitted model.
df_future_prediction: pandas.DataFrame = model.predict(df_future)

In [15]:
# Display the forecast for the future dates.
df_future_prediction

Unnamed: 0,ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
0,2023-03-01,5.530753,4.000432,7.302066,5.530753,5.530753,0.136035,0.136035,0.136035,0.136035,0.136035,0.136035,0.0,0.0,0.0,5.666788
1,2023-04-01,5.526702,3.93433,7.353113,5.526702,5.526702,0.116599,0.116599,0.116599,0.116599,0.116599,0.116599,0.0,0.0,0.0,5.643301
2,2023-05-01,5.522781,3.881598,7.392238,5.522781,5.522781,0.102483,0.102483,0.102483,0.102483,0.102483,0.102483,0.0,0.0,0.0,5.625264
