In [126]:
import os
import sys
import math
import pickle
import logging
from pathlib import Path

import scipy as sp

from sklearn.linear_model import LinearRegression

# Automatically reload edited project modules (e.g. `bhm_at_scale`) before each cell runs.
# NOTE(review): re-running this cell prints "already loaded"; that message is harmless.
%load_ext autoreload
%autoreload 2

import matplotlib as mpl
import matplotlib.pyplot as plt
# Render figures inline in the notebook at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
# Configure seaborn in a single call. `sns.set()` (alias of `set_theme`) resets
# the plotting context to "notebook" by default. A separate
# `sns.set_context("poster")` followed by `sns.set(...)` therefore silently
# loses the "poster" context, so context, style and rc overrides go together here.
sns.set(context="poster", style="whitegrid", rc={'figure.figsize': (16, 9.)})

import pandas as pd
# Raise pandas' preview limits so larger frames are shown without truncation.
for _display_option in ("display.max_rows", "display.max_columns"):
    pd.set_option(_display_option, 120)

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
_logger = logging.getLogger()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [127]:
from bhm_at_scale.handler import ModelHandler
from bhm_at_scale.model import model, guide, local_guide, check_model_guide, predictive_model, Site
from bhm_at_scale.utils import summary, stats_to_df, preds_to_df

In [128]:
import jax.numpy as jnp
from jax import random, ops
from jax import lax
from jax import jit
from jax.numpy import DeviceArray
import numpy as np
import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.infer import ELBO, SVI, Predictive
from numpyro.infer.svi import SVIState

In [129]:
# Load the preprocessed training tensor. `np.load` on an .npz archive returns a
# lazily-reading NpzFile that holds the file open, so close it via `with`.
# `jnp.array` copies the data, so the array stays valid after the file is closed.
with np.load('../data/preprocessed/X_train.npz') as npz_file:
    X_train = jnp.array(npz_file['arr_0'])
X_train.shape

(1000, 942, 24)

## Fit the hierarchical model

In [130]:
# Sanity-check the model/guide pair on the training data (presumably raises on a
# mismatch), then wrap both in a handler used for fitting.
check_model_guide(X_train, guide=guide, model=model)
train_handler = ModelHandler(guide=guide, model=model)

In [131]:
# Stage 1: coarse fit with a large learning rate.
train_handler.fit(
    X_train,
    lr=0.1,
    n_epochs=5_000,
    log_freq=1_000,
)

epoch:    0 loss:      114879.0547
epoch: 1000 loss:        6734.7690
epoch: 2000 loss:        6423.1729
epoch: 3000 loss:        6406.1953
epoch: 4000 loss:        6383.2485
epoch: 5000 loss:        6394.3232


6371.98291015625

In [132]:
# Stage 2: refine the coarse fit with a small learning rate.
train_handler.fit(
    X_train,
    lr=0.001,
    n_epochs=1_000,
    log_freq=200,
)

epoch:    0 loss:        6371.9829
epoch:  200 loss:        6369.5000
epoch:  400 loss:        6368.5176
epoch:  600 loss:        6368.0977
epoch:  800 loss:        6370.3887
epoch: 1000 loss:        6370.7646


6367.494140625

## Checkpoint: Save/restore current state

In [133]:
# Checkpoint the optimizer state to disk so the fit can be restored later.
optim_state_path = '../data/result/optim_state.pickle'
with open(optim_state_path, 'bw') as fh:
    train_handler.dump_optim_state(fh)

In [134]:
# Rebuild a fresh handler and restore the checkpointed optimizer state into it.
# NOTE(review): if `load_optim_state` unpickles, only load trusted checkpoint files.
train_handler = ModelHandler(model=model, guide=guide)
optim_state_path = '../data/result/optim_state.pickle'
with open(optim_state_path, 'br') as fh:
    train_handler.load_optim_state(fh)
# A short fit is needed to initialize `svi`. Note that it also runs 100 more
# epochs (presumably continuing from the restored state).
train_handler.fit(X_train, n_epochs=100, lr=0.001)

6367.4033203125

## Predict on training set and check fitted parameters

In [135]:
# Wrap the fitted parameters into a predictive model and share the trained optimizer state.
pred_model = predictive_model(train_handler.model_params)
pred_handler = ModelHandler(model=pred_model, guide=guide)
pred_handler.optim_state = train_handler.optim_state

In [136]:
# Draw 200 samples of the observation site on the training set.
preds_samples = pred_handler.predict(
    X_train, num_samples=200, return_sites=[Site.obs]
)

In [137]:
# Sample the latent sites `coefs`, `coef_mus` and `coef_sigmas` from the trained handler.
latent_samples = train_handler.predict(
    X_train,
    return_sites=[Site.coefs, Site.coef_mus, Site.coef_sigmas],
    num_samples=200,
)

In [138]:
# Persist the samples of `coef_mus` and `coef_sigmas`, one CSV per site.
for site in (Site.coef_mus, Site.coef_sigmas):
    samples_df = pd.DataFrame(latent_samples[site])
    csv_path = f'../data/result/{site}.csv'
    samples_df.to_csv(csv_path, index=False)

In [139]:
# Summarize the latent samples and label them with the feature names from edf.csv.
stats = summary(latent_samples, poisson=True)
df_edf = pd.read_csv('../data/preprocessed/edf.csv')
# All columns except the first two and the last one
# (presumably id/date columns and the target; TODO confirm).
feature_names = df_edf.columns[2:-1]
df_stats = stats_to_df(stats, feature_names)
df_stats.to_csv('../data/result/stats.csv', index=False)

In [140]:
# Summarize the training-set predictive samples and write them out.
preds = summary(preds_samples, poisson=False)
df_preds = preds_to_df(preds)
df_preds.to_csv(
    '../data/result/train_preds.csv',
    index=False,
)

## Predict on the test set with only a few known days of data

In [141]:
# Load the preprocessed test tensor. `np.load` on an .npz archive returns a
# lazily-reading NpzFile that holds the file open, so close it via `with`.
# `jnp.array` copies the data, so the array stays valid after the file is closed.
with np.load('../data/preprocessed/X_test.npz') as npz_file:
    X_test = jnp.array(npz_file['arr_0'])
X_test.shape

(115, 942, 24)

In [142]:
# Keep only the first `known_days` days of history along axis 1 (all features kept).
known_days = 7
X_test_known = X_test[:, :known_days]

### Fit on known data

In [143]:
# Handler whose guide is `local_guide`, built from the trained model parameters.
train_local_handler = ModelHandler(
    guide=local_guide(train_handler.model_params),
    model=model,
)

In [144]:
# Stage 1 on the known test days: large learning rate.
train_local_handler.fit(
    X_test_known,
    lr=0.1,
    n_epochs=1_000,
    log_freq=200,
)

epoch:    0 loss:          68.6854
epoch:  200 loss:          49.3262
epoch:  400 loss:          49.5047
epoch:  600 loss:          50.0747
epoch:  800 loss:          50.1147
epoch: 1000 loss:          50.4854


49.72366714477539

In [145]:
# Stage 2 on the known test days: refine with a small learning rate.
train_local_handler.fit(
    X_test_known,
    lr=0.001,
    n_epochs=1_000,
    log_freq=200,
)

epoch:    0 loss:          49.7237
epoch:  200 loss:          48.7054
epoch:  400 loss:          48.2247
epoch:  600 loss:          48.0160
epoch:  800 loss:          48.2504
epoch: 1000 loss:          48.5166


47.84402847290039

### Predict future of test data

In [146]:
# Combine the trained parameters with the locally fitted ones (local values win on
# key clashes). A new dict is built instead of calling `.update` on
# `train_handler.model_params`. If `model_params` returns a stored dict, an
# in-place update would leak the test-store parameters into `train_handler`,
# which is reused later. Building a new dict also makes this cell safe to re-run.
params = {**train_handler.model_params, **train_local_handler.model_params}
pred_local_handler = ModelHandler(model=predictive_model(params), guide=local_guide(params))
pred_local_handler.optim_state = train_local_handler.optim_state

In [147]:
# Predict over all days of X_test, not just the `known_days` used for fitting.
preds_samples = pred_local_handler.predict(
    X_test, num_samples=200, return_sites=[Site.obs]
)

In [148]:
# Summarize the test-set predictive samples and write them out.
preds = summary(preds_samples, poisson=False)
# Offset added to the test StoreIds, presumably so they don't collide with the
# training stores' ids. 1000 equals X_train.shape[0] (the number of training
# rows); TODO confirm this is how the ids were assigned during preprocessing.
TEST_STORE_ID_OFFSET = 1000
df_preds = preds_to_df(preds).assign(StoreId=lambda df: df.StoreId + TEST_STORE_ID_OFFSET)
df_preds.to_csv('../data/result/test_preds.csv', index=False)

### Compare with a conventional log-linear regression (linear regression on log-target) using scikit-learn

In [193]:
# Ordinary least-squares baseline (with intercept) for comparison with the Bayesian model.
reg = LinearRegression(fit_intercept=True)

In [211]:
# select a single store_id
store_id = 16
X = np.nan_to_num(X_test_known, nan=1.0)[store_id, ...]
X, y = X[:, :-1], X[:, -1]

In [232]:
# Fit on the log-transformed target so that the features act multiplicatively on y.
# NOTE(review): a zero in y would give log(0) = -inf and make the fit fail.
log_y = np.log(y)
reg.fit(X, log_y)

LinearRegression()

In [198]:
# Strong overfitting: with more features (23) than target values (7), the fit
# reproduces y almost exactly, so the residuals are ~0.
np.subtract(np.exp(reg.predict(X)), y)

array([-4.7683716e-07, -7.8125000e-03, -4.3945312e-03,  9.7656250e-04,
       -5.3710938e-03,  9.5367432e-07, -2.9296875e-03], dtype=float32)

In [199]:
# No overfitting in the Bayesian model: its mean predictions on the known days
# are not pinned to the observed values.
mean_obs_preds = jnp.mean(preds_samples[Site.obs], axis=0)
mean_obs_preds[store_id, :known_days] - y

DeviceArray([5639.725  , -195.73486, -948.5249 , -158.41504,  -18.72998,
             5615.995  ,  186.97461], dtype=float32)

### Compare the coefficients of conventional regression to the hierarchical model

In [245]:
# For many features there is no meaningful coefficient (it is exactly 0), since
# those features were not encountered in the few known days used for this fit.
reg.coef_

array([ 1.1747565, -2.2386408,  2.144292 ,  1.9889385,  1.9385108,
        1.8024142, -6.8102694,  1.1747564,  2.2386408, -2.2386408,
        0.       ,  0.       , -0.0943485,  0.       ,  0.       ,
        0.       ,  0.       ,  0.       ,  0.       ,  0.       ,
        0.       ,  0.       ,  0.       ], dtype=float32)

In [238]:
# Using the global prior, it is possible to derive meaningful coefficient values
# even from only the known days.
coefs_samples = pred_local_handler.predict(
    X_test_known, return_sites=[Site.coefs], num_samples=200
)
store_mean_coefs = jnp.mean(coefs_samples[Site.coefs], axis=0)[store_id]
print(store_mean_coefs)

[2.7901876  2.6591206  2.688186   2.6126776  2.6567733  2.5543957
 2.4938378  0.33227593 3.1156936  2.9264405  2.6920953  2.9548461
 0.05613961 0.06542101 2.8379257  2.9023957  3.5701392  3.207435
 4.056985   2.9304533  2.7463422  2.8231895  2.9590063 ]


### Now compare those coefficients to the ones fitted on the whole time series

In [228]:
# Refit the local parameters for this single store, this time on its whole time series.
store_series = X_test[store_id:store_id + 1]  # slicing keeps the leading store axis
all_local_handler = ModelHandler(model=model, guide=local_guide(train_handler.model_params))
all_local_handler.fit(store_series, n_epochs=10_000, log_freq=1_000, lr=0.001)

epoch:     0 loss:        7813.2842
epoch:  1000 loss:        7226.6973
epoch:  2000 loss:        7251.3184
epoch:  3000 loss:        7163.0728
epoch:  4000 loss:        7166.2578
epoch:  5000 loss:        7116.2334
epoch:  6000 loss:        7119.3320
epoch:  7000 loss:        7191.9634
epoch:  8000 loss:        7127.7339
epoch:  9000 loss:        7166.2593
epoch: 10000 loss:        7137.1484


7158.21337890625

In [239]:
# Many coefficients are very similar, but mind that they live in log-space!
all_coefs_samples = all_local_handler.predict(
    X_test[store_id:store_id + 1], return_sites=[Site.coefs], num_samples=200
)
# Index 0 is the only store in the one-store batch.
all_mean_coefs = jnp.mean(all_coefs_samples[Site.coefs], axis=0)[0]
print(all_mean_coefs)

[ 2.8149958   2.7551858   2.618216    2.6453686   2.6726904   2.5201018
  2.5930352   0.22652954  3.1844485   3.1163406   2.542964    2.9477344
 -0.03218627  0.06836596  2.8726478   2.925491    3.5667892   3.2158158
  4.0523424   2.9164746   2.7241354   2.8247733   2.9598224 ]
