In [1]:
import os
import sys
import math
import pickle
import logging
from pathlib import Path

import scipy as sp

from sklearn.linear_model import LinearRegression

%load_ext autoreload
%autoreload 2

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
sns.set_context("poster")
sns.set(rc={'figure.figsize': (16, 9.)})
sns.set_style("whitegrid")

import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
_logger = logging.getLogger()

In [2]:
from bhm_at_scale.handler import ModelHandler
from bhm_at_scale.model import model, guide, local_guide, check_model_guide, predictive_model, Site
from bhm_at_scale.utils import summary, stats_to_df, preds_to_df

In [3]:
import jax.numpy as jnp
from jax import random, ops
from jax import lax
from jax import jit
from jax.numpy import DeviceArray
import numpy as np
import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.infer import ELBO, SVI, Predictive
from numpyro.infer.svi import SVIState

In [4]:
X_train = jnp.array(np.load('../data/preprocessed/X_train.npz')['arr_0'])
X_train.shape

(1000, 942, 24)

## Fit the hierachical model

In [12]:
check_model_guide(X_train, model=model, guide=guide)
train_handler = ModelHandler(model=model, guide=guide)

In [13]:
train_handler.fit(X_train, n_epochs=5_000, log_freq=1_000, lr=0.1)

epoch:    0 loss:      114879.0703
epoch: 1000 loss:        6734.7886
epoch: 2000 loss:        6423.1748
epoch: 3000 loss:        6406.2109
epoch: 4000 loss:        6383.2534
epoch: 5000 loss:        6394.3198


6371.9638671875

In [7]:
train_handler.fit(X_train, n_epochs=1_000, log_freq=200, lr=0.001)

epoch:    0 loss:              nan
epoch:  200 loss:              nan
epoch:  400 loss:              nan
epoch:  600 loss:              nan
epoch:  800 loss:              nan
epoch: 1000 loss:              nan


nan

## Checkpoint: Save/restore current state

In [14]:
with open('../data/result/optim_state.pickle', 'bw') as fh:
    train_handler.dump_optim_state(fh)

In [15]:
train_handler = ModelHandler(model=model, guide=guide)
with open('../data/result/optim_state.pickle', 'br') as fh:
     train_handler.load_optim_state(fh)
# this is needed to initialize `svi`
train_handler.fit(X_train, n_epochs=100, lr=0.001)

6370.47119140625

## Predict on training set and check fitted parameters

In [16]:
pred_handler = ModelHandler(model=predictive_model(train_handler.model_params), guide=guide)
pred_handler.optim_state = train_handler.optim_state 

In [17]:
preds_samples = pred_handler.predict(X_train, return_sites=[Site.days], num_samples=200)

In [None]:
latent_samples = train_handler.predict(X_train, return_sites=[Site.coefs, Site.coef_mus, Site.coef_sigmas], num_samples=200)

In [None]:
for site in [Site.coef_mus, Site.coef_sigmas]:
    samples_df = pd.DataFrame(latent_samples[site])
    samples_df.to_csv(f'../data/result/{site}.csv', index=False)

In [None]:
stats = summary(latent_samples, poisson=True)
df_edf = pd.read_csv('../data/preprocessed/edf.csv')
df_stats = stats_to_df(stats, df_edf.columns[2:-1])
df_stats.to_csv('../data/result/stats.csv', index=False)

In [None]:
preds = summary(preds_samples, poisson=False)
df_preds = preds_to_df(preds[Site.days])
df_preds.to_csv('../data/result/train_preds.csv', index=False)

## Predict on test set with only little data

In [None]:
X_test = jnp.array(np.load('../data/preprocessed/X_test.npz')['arr_0'])
X_test.shape

In [None]:
known_days = 7  # consider only known days of history
X_test_known = X_test[:, :known_days, :]

### Fit on known data

In [None]:
train_local_handler = ModelHandler(model=model, guide=local_guide(train_handler.model_params))

In [None]:
train_local_handler.fit(X_test_known, n_epochs=1_000, log_freq=200, lr=0.1)

In [None]:
train_local_handler.fit(X_test_known, n_epochs=1_000, log_freq=200, lr=0.001)

### Predict future of test data

In [None]:
params = train_handler.model_params
params.update(train_local_handler.model_params)
pred_local_handler = ModelHandler(model=predictive_model(params), guide=local_guide(params))
pred_local_handler.optim_state = train_local_handler.optim_state 

In [None]:
preds_samples = pred_local_handler.predict(X_test, return_sites=[Site.days], num_samples=200)

In [None]:
preds = summary(preds_samples, poisson=False)
df_preds = preds_to_df(preds[Site.days]).assign(StoreId=lambda df: df.StoreId + 1000)
df_preds.to_csv('../data/result/test_preds.csv', index=False)

### Compare with conventional Poisson regression using Scikit-Learn

In [None]:
reg = LinearRegression()

In [None]:
# select a single store_id
store_id = 16
X = np.nan_to_num(X_test_known, nan=1.0)[store_id, ...]
X, y = X[:, :-1], X[:, -1]

In [None]:
# we fit on the log-transformed target to achieve a multiplicate relationship
reg.fit(X, np.log(y))

In [None]:
# high overfit since we have more features than target values
np.exp(reg.predict(X)) - y

In [None]:
# no overfitting in case of the Bayesian model
jnp.mean(preds_samples[Site.days], axis=0)[store_id][:known_days] - y

### Compare the coefficients of conventional regression to the hierarchical model

In [None]:
# for many feature there is no meaningful value, i.e. 0, since they were not encountered in training
print(reg.coef_)

In [None]:
# using the global prior it's possible to derive meaningful values
coefs_samples = pred_local_handler.predict(X_test_known, return_sites=[Site.coefs], num_samples=200)
print(jnp.mean(coefs_samples[Site.coefs], axis=0)[store_id])

## Now compare those coefficients to the ones fitted on the whole time-series

In [None]:
all_local_handler = ModelHandler(model=model, guide=local_guide(train_handler.model_params))
all_local_handler.fit(X_test[store_id:store_id+1], n_epochs=10_000, log_freq=1_000, lr=0.001)

In [None]:
# many coefficients are really similar but mind the log-space!
all_coefs_samples = all_local_handler.predict(X_test[store_id:store_id+1], return_sites=[Site.coefs], num_samples=200)
print(jnp.mean(all_coefs_samples[Site.coefs], axis=0)[0])