## Notebook Configuration & Imports

In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all imported modules before each cell execution.
%autoreload 2

In [2]:
import os
import itertools
import logging
from tqdm import tqdm

import pandas as pd
from prophet import Prophet

from storesales.constants import (
    SUBMISSIONS_PATH,
    EXTERNAL_TRAIN_PATH,
    EXTERNAL_SAMPLE_SUBMISSION_PATH,
    EXTERNAL_TEST_PATH,
)

## Load & Prepare Data

In [3]:
# Training history, ordered by date, then store number, then product family.
original_train_df = pd.read_csv(EXTERNAL_TRAIN_PATH, parse_dates=["date"]).sort_values(
    by=["date", "store_nbr", "family"]
)

# Test rows, with the "date" column parsed as datetimes.
original_test_df = pd.read_csv(EXTERNAL_TEST_PATH, parse_dates=["date"])

# Submission template, indexed by the "id" column.
sample_submission_df = pd.read_csv(EXTERNAL_SAMPLE_SUBMISSION_PATH, index_col="id")

In [4]:
# Distinct dates present in the training data.
# The frame is read without `index_col`, so its index is just the default
# integer row labels; the dates have to come from the "date" column.
# NOTE(review): assumes "train period" means the set of training dates - confirm.
train_period = pd.DatetimeIndex(original_train_df["date"].unique())

## Facebook Prophet

In [6]:
# Prophet expects the timestamp column to be called "ds" and the target "y".
prophet_column_names = {"date": "ds", "sales": "y"}
train_df = original_train_df[["date", "store_nbr", "family", "sales"]].rename(
    columns=prophet_column_names
)

# The test frame only needs its timestamp column renamed.
test_df = original_test_df.rename(columns={"date": "ds"})

# Work on a copy so the sample submission template stays untouched.
prophet_submission_df = sample_submission_df.copy()

In [7]:
# Group both frames by (store_nbr, family) so each individual series can be
# looked up with `get_group`.
train_groups, test_groups = (
    frame.groupby(["store_nbr", "family"]) for frame in (train_df, test_df)
)

In [8]:
# Distinct store numbers and product families seen in the training data,
# in order of first appearance.
stores = pd.unique(train_df["store_nbr"])
families = pd.unique(train_df["family"])

In [9]:
# Every (store_nbr, family) pair to model.
# Materialised as a list: a bare `itertools.product` iterator is exhausted after
# a single pass, so re-running the training loop cell would silently do nothing.
groups = list(itertools.product(stores, families))
# Number of loop iterations, used to size the progress bar (one tick per pair).
total = len(groups)

In [None]:
# Only let errors through from the Prophet and cmdstanpy loggers.
for noisy_logger_name in ("prophet", "cmdstanpy"):
    logging.getLogger(noisy_logger_name).setLevel(logging.ERROR)

In [12]:
# Fit one Prophet model per (store_nbr, family) series and write its forecast
# into the submission frame.
for store_nbr_to_family in tqdm(groups, total=total):
    train_group = train_groups.get_group(store_nbr_to_family)
    # Fresh 0..n-1 index so the forecast rows can be aligned with the test rows.
    # NOTE(review): this alignment assumes the test rows of each group are
    # already in ascending "ds" order - confirm against the test file.
    test_group = test_groups.get_group(store_nbr_to_family).reset_index(drop=True)

    m = Prophet()
    m.fit(train_group)

    forecast = m.predict(test_group)

    # Prophet's forecast is not bounded below, so it can dip under zero;
    # sales cannot be negative, hence the clip.
    test_group["yhat"] = forecast["yhat"].clip(lower=0)

    # Submission rows are keyed by "id", so index the forecast the same way
    # before assigning.
    test_group.set_index("id", inplace=True)
    prophet_submission_df.loc[test_group.index, "sales"] = test_group["yhat"]

 99%|█████████▉| 1771/1782 [06:59<00:02,  4.22it/s]


In [13]:
# Write the filled-in submission (index included) into the submissions directory.
prophet_submission_file = "prophet_submission.csv"
prophet_submission_path = os.path.join(SUBMISSIONS_PATH, prophet_submission_file)
prophet_submission_df.to_csv(prophet_submission_path)