# Speeding up PyStan demos on Google Colab with model caching

To speed up demos with caching, the user needs to upload the compiled models to GitHub.

In [1]:
# import used libraries
import os
from stan_colab_utils import StanModel, install

In [3]:
# Install PyStan first; the import on the next line must come after this call.
# NOTE(review): `install` comes from stan_colab_utils -- presumably a pip wrapper; confirm.
install("pystan")
import pystan

In [4]:
# Relative locations of the 8-schools example files used throughout the notebook.
eight_schools_path = "models/eight_schools/8schools.stan"  # Stan program source
eight_schools_data_path = "models/eight_schools/8schools.data.R"  # data file loaded with read_rdump
eight_schools_model_cache = "models/eight_schools/8schools_model.gz"  # cache holding the model only
eight_schools_model_fit_cache = "models/eight_schools/8schools_model_fit.gz"  # cache holding [model, fit]

## 8-schools model

In [5]:
# Show the Stan program text so the reader can see the model compiled below.
with open(eight_schools_path) as stan_file:
    stan_code = stan_file.read()
print(stan_code)

data {
    int<lower=0> J; // number of schools
    vector[J] y; // estimated treatment effects
    vector<lower=0>[J] sigma; // s.e. of effect estimates
}
parameters {
    real mu;
    real<lower=0> tau;
    vector[J] eta;
}
transformed parameters {
    vector[J] theta;
    theta = mu + tau * eta;
}
model {
    eta ~ normal(0, 1);
    y ~ normal(theta, sigma);
}


## 8-schools data

In [6]:
# Parse the R-dump data file; the result is what gets passed to `sampling(data=...)` below.
stan_data = pystan.read_rdump(eight_schools_data_path)
# Bare expression so the loaded data is displayed as the cell output.
stan_data

OrderedDict([('J', array(8)),
             ('y', array([28,  8, -3,  7, -1,  1, 18, 12])),
             ('sigma', array([15, 10, 16, 11,  9, 11, 10, 18]))])

In [7]:
# Seed passed to both `sampling` calls below so the two runs use the same seed.
stan_seed = 2018

### Compile model and sample

In [8]:
%%time
# Baseline without caching: build the model from the Stan file, then sample.
# Each step is timed on its own with %time; %%time reports the total for the cell.
%time stan_model = pystan.StanModel(file=eight_schools_path)
%time fit = stan_model.sampling(data=stan_data, seed=stan_seed)

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e NOW.


Wall time: 41.9 s
Wall time: 1.62 s
Wall time: 43.5 s


In [12]:
# Summary of the fit from the freshly compiled model; reference for the cached runs below.
print(fit)

Inference for Stan model: anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.07    0.12   5.22  -2.18   4.84   8.03  11.25  18.29   2044    1.0
tau        6.43    0.14   5.39   0.27   2.42   5.09   8.98  20.16   1470    1.0
eta[1]     0.37    0.02   0.94  -1.49  -0.25    0.4   1.03   2.13   3701    1.0
eta[2]  -3.7e-3    0.01   0.91  -1.83   -0.6  -0.02    0.6   1.82   3940    1.0
eta[3]    -0.18    0.01   0.93  -2.04  -0.78  -0.16   0.43   1.69   4364    1.0
eta[4]    -0.05    0.01   0.87  -1.75  -0.61  -0.05   0.51   1.71   4077    1.0
eta[5]    -0.32    0.01   0.85  -1.97  -0.88  -0.34   0.23    1.4   3565    1.0
eta[6]    -0.22    0.01   0.89  -1.97  -0.81  -0.24   0.38   1.57   4363    1.0
eta[7]     0.32    0.01   0.88  -1.46  -0.24   0.33   0.91   1.98   4159    1.0
eta[8]     

In [14]:
import os
from stan_colab_utils import save, read

# Write each cache file only when it is not already on disk, so an existing
# cache is never overwritten.
cache_targets = [
    (eight_schools_model_cache, stan_model),              # model only
    (eight_schools_model_fit_cache, [stan_model, fit]),   # model together with its fit
]
for cache_file, cached_object in cache_targets:
    if os.path.exists(cache_file):
        continue
    save(cache_file, cached_object)

### Read model from cache (or compile if not found) and sample

In [15]:
%%time
# Same two steps as the baseline cell, but through the StanModel wrapper from
# stan_colab_utils, with cache_path pointing at the model cache file.
# NOTE(review): per the section heading the wrapper compiles only when the cache is missing -- confirm in stan_colab_utils.
%time stan_model2 = StanModel(file=eight_schools_path, cache_path=eight_schools_model_cache)
%time fit2 = stan_model2.sampling(data=stan_data, seed=stan_seed)

Wall time: 50 ms
Wall time: 1.53 s
Wall time: 1.58 s


In [16]:
# Summary of the fit sampled with the cache-backed model, for comparison with the baseline summary.
print(fit2)

Inference for Stan model: anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.07    0.12   5.22  -2.18   4.84   8.03  11.25  18.29   2044    1.0
tau        6.43    0.14   5.39   0.27   2.42   5.09   8.98  20.16   1470    1.0
eta[1]     0.37    0.02   0.94  -1.49  -0.25    0.4   1.03   2.13   3701    1.0
eta[2]  -3.7e-3    0.01   0.91  -1.83   -0.6  -0.02    0.6   1.82   3940    1.0
eta[3]    -0.18    0.01   0.93  -2.04  -0.78  -0.16   0.43   1.69   4364    1.0
eta[4]    -0.05    0.01   0.87  -1.75  -0.61  -0.05   0.51   1.71   4077    1.0
eta[5]    -0.32    0.01   0.85  -1.97  -0.88  -0.34   0.23    1.4   3565    1.0
eta[6]    -0.22    0.01   0.89  -1.97  -0.81  -0.24   0.38   1.57   4363    1.0
eta[7]     0.32    0.01   0.88  -1.46  -0.24   0.33   0.91   1.98   4159    1.0
eta[8]     

## Read sampling results

In [17]:
%%time
# For slow models, it might be a good idea to save presampled data.
# This cell makes no compile or sampling call: it reads the [model, fit] pair back from the fit cache file.
# NOTE(review): if `read` unpickles the file, only load caches from trusted sources -- confirm.
%time stan_model3, fit3 = read(eight_schools_model_fit_cache)

Wall time: 59 ms
Wall time: 62 ms


In [18]:
# Summary of the fit object read back from the fit cache file.
print(fit3)

Inference for Stan model: anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.07    0.12   5.22  -2.18   4.84   8.03  11.25  18.29   2044    1.0
tau        6.43    0.14   5.39   0.27   2.42   5.09   8.98  20.16   1470    1.0
eta[1]     0.37    0.02   0.94  -1.49  -0.25    0.4   1.03   2.13   3701    1.0
eta[2]  -3.7e-3    0.01   0.91  -1.83   -0.6  -0.02    0.6   1.82   3940    1.0
eta[3]    -0.18    0.01   0.93  -2.04  -0.78  -0.16   0.43   1.69   4364    1.0
eta[4]    -0.05    0.01   0.87  -1.75  -0.61  -0.05   0.51   1.71   4077    1.0
eta[5]    -0.32    0.01   0.85  -1.97  -0.88  -0.34   0.23    1.4   3565    1.0
eta[6]    -0.22    0.01   0.89  -1.97  -0.81  -0.24   0.38   1.57   4363    1.0
eta[7]     0.32    0.01   0.88  -1.46  -0.24   0.33   0.91   1.98   4159    1.0
eta[8]     

# Environment

In [31]:
import sys
from datetime import datetime

# numpy was never imported anywhere in the notebook, so `np` only existed as
# leftover kernel state; import it here so the cell works on a fresh kernel.
import numpy as np

# Record when the notebook was run and which interpreter/library versions were used.
print("Notebook date:", datetime.now().date(), "\n")
# The interpreter exposes its version as `sys.version`, not `__version__`.
print("python", sys.version)
for tool in [np, pystan]:
    print(tool.__name__, tool.__version__)

Notebook date: 2018-08-28 

python 3.7.0 | packaged by conda-forge | (default, Jul 13 2018, 23:54:23) [MSC v.1900 64 bit (AMD64)]
numpy 1.15.0
pystan 2.18.0.0
