In [1]:
import sys
sys.path.append('..')
from src.utilities import *

import warnings
warnings.filterwarnings("ignore")

import logging, sys
logging.disable(sys.maxsize)

from fbprophet import Prophet
from fbprophet.plot import plot_plotly

## 1. Prophet feature extraction

### 1.1 Extract Prophet forecasts for each month-brand-region, fitting each model only on months before the target month

In [2]:
# Raw sales table; the Prophet extraction below reads its
# `month`, `region`, `brand` and `sales` columns.
sales_train = pd.read_csv(
    os.path.join(raw_path, 'sales_train.csv')
)

# Default target months and brand series for `extract_prophet_features`.
PROPHET_MONTHS_TO_PREDICT = ('2020-06', '2020-07', '2020-08', '2020-09', '2020-10', '2020-11', '2020-12', '2021-01',
                             '2021-02', '2021-03', '2021-04', '2021-05', '2021-06', '2021-07', '2021-08')
PROPHET_BRANDS_TO_PREDICT = ('brand_12_market', 'brand_3', 'brand_3_market')

# (column in the long per-fit results, suffix used in the output feature names)
_PROPHET_METRICS = (('trend', 'trend'), ('prediction', 'yhat'), ('additive_terms', 'add'))


def _fit_prophet_forecasts(raw_sales: pd.DataFrame, months, brands, n_iter: int) -> pd.DataFrame:
    '''Fit one Prophet model per region-brand-month on earlier months; return long-format results.'''
    records = []
    for region in tqdm(raw_sales.region.unique()):
        region_sales = raw_sales[raw_sales.region == region]
        for brand in brands:
            # Hoisted out of the month loop: this filter does not depend on the target month.
            brand_sales = region_sales.loc[region_sales.brand == brand, ['month', 'sales']]
            for month in months:
                # `month` is compared as a string; assumes raw values are 'YYYY-MM' so that
                # `<` keeps only months strictly before the target (no look-ahead).
                history = (brand_sales[brand_sales.month < month]
                           .rename(columns={'month': 'ds', 'sales': 'y'})
                           .assign(ds=lambda frame: pd.to_datetime(frame['ds'])))

                model = Prophet(yearly_seasonality=True, weekly_seasonality=False, daily_seasonality=False)
                # Extra fit kwargs are handed to Stan's optimizer; `iter` caps its iterations.
                model.fit(history, iter=n_iter)
                forecast = model.predict(pd.DataFrame({'ds': [month]}))

                records.append({'month': month,
                                'brand': brand,
                                'region': region,
                                'trend': forecast['trend'].iloc[0],
                                'prediction': forecast['yhat'].iloc[0],
                                'additive_terms': forecast['additive_terms'].iloc[0]})

    # Build the frame once instead of concatenating row by row (quadratic).
    return pd.DataFrame(records, columns=['month', 'brand', 'region', 'trend', 'prediction', 'additive_terms'])


def _pivot_prophet_results(results: pd.DataFrame) -> pd.DataFrame:
    '''Reshape long results into one row per (month, region) with one column per brand and metric.'''
    wide_frames = []
    for value_col, suffix in _PROPHET_METRICS:
        wide = results.pivot(index=['month', 'region'], columns='brand', values=value_col)
        # Derive names from each brand label ('brand_12_market' -> 'prophet_b12_market_<suffix>')
        # rather than relying on the pivot's column order or on an exact brand set.
        wide.columns = [f"prophet_{brand.replace('brand_', 'b', 1)}_{suffix}" for brand in wide.columns]
        wide_frames.append(wide)
    # All pivots share the same (month, region) index, so a column-wise concat aligns them.
    return pd.concat(wide_frames, axis=1).reset_index()


def extract_prophet_features(raw_sales: pd.DataFrame,
                             months_to_predict=PROPHET_MONTHS_TO_PREDICT,
                             brands_to_predict=PROPHET_BRANDS_TO_PREDICT,
                             n_iter: int = 60) -> pd.DataFrame:
    '''Provides prediction, additive term and trend term for every month-region-brand.

    For every region in `raw_sales`, every brand in `brands_to_predict` and every
    target month in `months_to_predict`, a Prophet model (yearly seasonality only)
    is fitted on that region-brand's sales from months strictly before the target
    month. Its trend, yhat and additive terms at the target month are kept.

    Args:
        raw_sales: long sales table with `month`, `region`, `brand` and `sales`
            columns; `month` is compared as a string (assumed 'YYYY-MM').
        months_to_predict: target months as 'YYYY-MM' strings.
        brands_to_predict: values of `brand` to model.
        n_iter: optimizer iteration cap passed to `Prophet.fit` as `iter`.

    Returns:
        One row per (month, region), ordered by month then region, with columns
        `month`, `region`, then per brand `prophet_<b>_trend`, then
        `prophet_<b>_yhat`, then `prophet_<b>_add`. `<b>` is the brand label with
        its 'brand_' prefix shortened to 'b' (e.g. `prophet_b3_yhat`).

    Raises:
        Whatever `Prophet.fit` raises, e.g. when a region-brand has too little
        history before a target month.
    '''
    results = _fit_prophet_forecasts(raw_sales, months_to_predict, brands_to_predict, n_iter)
    return _pivot_prophet_results(results)

# The full Prophet sweep takes hours (see the progress bar below), which makes
# Restart & Run All impractical. Reuse the CSV written in section 1.2 when it
# exists. Set RECOMPUTE_PROPHET_FEATURES = True to force a fresh run, e.g. after
# sales_train.csv changes; otherwise a stale cache would be reused.
RECOMPUTE_PROPHET_FEATURES = False
prophet_features_cache = os.path.join(interim_path, 'prophet_features.csv')

if RECOMPUTE_PROPHET_FEATURES or not os.path.exists(prophet_features_cache):
    prophet_features = extract_prophet_features(sales_train)
else:
    prophet_features = pd.read_csv(prophet_features_cache)

prophet_features.head()

100%|██████████████████████████████████████████████████████████████████████████████| 201/201 [5:21:32<00:00, 95.98s/it]


Unnamed: 0,month,region,prophet_b12_market_trend,prophet_b3_trend,prophet_b3_market_trend,prophet_b12_market_yhat,prophet_b3_yhat,prophet_b3_market_yhat,prophet_b12_market_add,prophet_b3_add,prophet_b3_market_add
0,2020-06,region_0,3352784.0,158064.647452,-2884394.0,6993132.0,-503312.294078,-25888420.0,3640348.0,-661376.94153,-23004020.0
1,2020-06,region_1,56272.73,138787.230563,-3261530.0,-719883.8,201199.050187,-4020387.0,-776156.5,62411.819624,-758856.6
2,2020-06,region_10,121806.0,-119918.314067,-1200203.0,-1873387.0,432587.526283,-10074190.0,-1995193.0,552505.840349,-8873984.0
3,2020-06,region_100,-347997.8,-30921.792572,2572503.0,-5387159.0,203836.797751,6469786.0,-5039161.0,234758.590323,3897283.0
4,2020-06,region_101,1692713.0,15590.274969,367181.6,144102.5,77878.896363,3885137.0,-1548610.0,62288.621394,3517955.0


### 1.2 Store Prophet features in the interim folder

In [3]:
# Write the features to interim_path as CSV, without the pandas RangeIndex.
prophet_features.to_csv(
    os.path.join(interim_path, 'prophet_features.csv'),
    index=False,
)