# Speeding up PyStan demos on Google Colab with model caching

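To speed up the demos with caching, the notebook owner needs to upload the compiled models to GitHub; the notebook then loads them from the cloned repository instead of recompiling.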

In [1]:
!git clone https://github.com/ahartikainen/PyStan_Google_Colab_Demo
import os
os.chdir("PyStan_Google_Colab_Demo")
!git fetch --all
!git reset --hard origin/master

Cloning into 'PyStan_Google_Colab_Demo'...
remote: Counting objects: 34, done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 34 (delta 13), reused 23 (delta 6), pack-reused 0[K
Unpacking objects: 100% (34/34), done.
Fetching origin
HEAD is now at e1eeb07 add read and save


In [2]:
# List the repository contents to confirm the clone and the working directory.
!ls

example_cache.ipynb  LICENSE  models  README.md  stan_colab_utils.py


In [0]:
from stan_colab_utils import StanModel, install, read, save

In [4]:
# Install PyStan into the Colab runtime using the repo helper, then import it.
# NOTE(review): install() comes from stan_colab_utils (not shown); judging by the
# output it wraps pip. Consider pinning a version for reproducibility.
install("pystan")
import pystan

Collecting pystan
  Using cached https://files.pythonhosted.org/packages/46/37/801a5a932e7f1f038542e7c5e4c4010aac19a26ea6bde9534505465f8c8c/pystan-2.17.1.0-cp36-cp36m-manylinux1_x86_64.whl
Installing collected packages: pystan
Successfully installed pystan-2.17.1.0


In [0]:
# Paths, relative to the repository root, for the 8-schools example:
# the Stan program, its R-dump data file, and the two cache files
# (compiled model only, and compiled model together with its fit).
EIGHT_SCHOOLS_DIR = "models/eight_schools"

eight_schools_path = f"{EIGHT_SCHOOLS_DIR}/8schools.stan"
eight_schools_data_path = f"{EIGHT_SCHOOLS_DIR}/8schools.data.R"
eight_schools_model_cache = f"{EIGHT_SCHOOLS_DIR}/8schools_model.gz"
eight_schools_model_fit_cache = f"{EIGHT_SCHOOLS_DIR}/8schools_model_fit.gz"

## 8-schools model

In [6]:
# Show the Stan program source so readers can see the model compiled below.
with open(eight_schools_path) as stan_program_file:
    stan_program_code = stan_program_file.read()
print(stan_program_code)

data {
    int<lower=0> J; // number of schools
    vector[J] y; // estimated treatment effects
    vector<lower=0>[J] sigma; // s.e. of effect estimates
}
parameters {
    real mu;
    real<lower=0> tau;
    vector[J] eta;
}
transformed parameters {
    vector[J] theta;
    theta = mu + tau * eta;
}
model {
    eta ~ normal(0, 1);
    y ~ normal(theta, sigma);
}


## 8-schools data

In [7]:
# Load the model data from the R-dump file and display it; the output below
# shows a mapping from Stan data names (J, y, sigma) to arrays.
stan_data = pystan.read_rdump(eight_schools_data_path)
stan_data

OrderedDict([('J', array(8)),
             ('y', array([28,  8, -3,  7, -1,  1, 18, 12])),
             ('sigma', array([15, 10, 16, 11,  9, 11, 10, 18]))])

In [0]:
# Fixed sampler seed, reused by every sampling call below so the fresh and
# cached runs produce comparable (here identical) summaries.
stan_seed = 2018

### Compile model and sample

In [9]:
%%time
# Compile the Stan program from source (the slow step, about a minute in the
# output below), then sample. The inner %time lines time each step separately;
# %%time reports the whole cell.
%time stan_model = pystan.StanModel(file=eight_schools_path)
%time fit = stan_model.sampling(data=stan_data, seed=stan_seed)

COMPILING THE C++ CODE FOR MODEL anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e NOW.
CPU times: user 1.1 s, sys: 208 ms, total: 1.31 s
Wall time: 1min 3s
CPU times: user 10.7 ms, sys: 39.3 ms, total: 50 ms
Wall time: 357 ms
CPU times: user 1.11 s, sys: 249 ms, total: 1.36 s
Wall time: 1min 4s


In [10]:
# Print the plain-text posterior summary of the freshly compiled model's fit.
print(fit)

Inference for Stan model: anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.03    0.11   5.12   -1.8   4.68   8.01  11.21   18.4   2243    1.0
tau        6.63    0.14   5.46   0.26   2.48   5.45   9.28  20.87   1557    1.0
eta[0]     0.39    0.01   0.95  -1.52  -0.24   0.41   1.05   2.14   4000    1.0
eta[1]    -0.02    0.01   0.87  -1.76  -0.59  -0.01   0.55   1.66   4000    1.0
eta[2]    -0.19    0.01   0.93  -2.02  -0.81  -0.18   0.43   1.66   4000    1.0
eta[3]    -0.04    0.01   0.88  -1.76  -0.61  -0.05   0.53   1.71   4000    1.0
eta[4]    -0.37    0.01   0.85   -2.1  -0.94  -0.36   0.18   1.32   4000    1.0
eta[5]    -0.23    0.01   0.89  -1.99  -0.81  -0.23   0.34   1.58   4000    1.0
eta[6]     0.34    0.01   0.88  -1.45   -0.2   0.36   0.94   2.05   4000    1.0
eta[7]     

#### Update the cache (owner only, optional)

In [0]:
# Owner only: set to True to re-save the compiled model, and the model
# together with its fit, into the cache files read later in this notebook.
UPDATE_CACHE = False

if UPDATE_CACHE:
    # `save` is imported from stan_colab_utils at the top of the notebook.
    save(eight_schools_model_cache, stan_model)
    save(eight_schools_model_fit_cache, [stan_model, fit])

    # To publish the updated caches:
    #   1. Download both .gz files to your local machine (Colab's left
    #      sidebar -> Files; refresh the listing if they do not appear).
    #   2. Copy them into models/eight_schools/ in your local clone.
    #   3. git add, commit and push them.
    

### Load the compiled model from cache (or compile it if no cache is found) and sample

In [12]:
%%time
# StanModel here is the stan_colab_utils wrapper (imported at the top), not
# pystan.StanModel: given cache_path, it loads the compiled model from the cache
# instead of recompiling. Compare the wall time with the compile cell above.
# Sampling uses the same data and seed, so the summary should match `fit`.
%time stan_model2 = StanModel(file=eight_schools_path, cache_path=eight_schools_model_cache)
%time fit2 = stan_model2.sampling(data=stan_data, seed=stan_seed)

CPU times: user 156 ms, sys: 195 ms, total: 351 ms
Wall time: 352 ms
CPU times: user 12 ms, sys: 29.6 ms, total: 41.6 ms
Wall time: 350 ms
CPU times: user 170 ms, sys: 225 ms, total: 395 ms
Wall time: 705 ms


In [13]:
# Summary of the fit from the cached model; it should match the summary of `fit`.
print(fit2)

Inference for Stan model: anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.03    0.11   5.12   -1.8   4.68   8.01  11.21   18.4   2243    1.0
tau        6.63    0.14   5.46   0.26   2.48   5.45   9.28  20.87   1557    1.0
eta[0]     0.39    0.01   0.95  -1.52  -0.24   0.41   1.05   2.14   4000    1.0
eta[1]    -0.02    0.01   0.87  -1.76  -0.59  -0.01   0.55   1.66   4000    1.0
eta[2]    -0.19    0.01   0.93  -2.02  -0.81  -0.18   0.43   1.66   4000    1.0
eta[3]    -0.04    0.01   0.88  -1.76  -0.61  -0.05   0.53   1.71   4000    1.0
eta[4]    -0.37    0.01   0.85   -2.1  -0.94  -0.36   0.18   1.32   4000    1.0
eta[5]    -0.23    0.01   0.89  -1.99  -0.81  -0.23   0.34   1.58   4000    1.0
eta[6]     0.34    0.01   0.88  -1.45   -0.2   0.36   0.94   2.05   4000    1.0
eta[7]     

## Read cached sampling results

In [14]:
%%time
# For slow models, it might be a good idea to save presampled data.
# read() is expected to return what the owner-only cell saved ([model, fit]),
# so no compilation or sampling happens here.
# NOTE(review): confirm how stan_colab_utils.read deserializes the file; if it
# unpickles, only load cache files from a trusted source.
%time stan_model3, fit3 = read(eight_schools_model_fit_cache)

CPU times: user 170 ms, sys: 62.7 ms, total: 233 ms
Wall time: 234 ms
CPU times: user 172 ms, sys: 62.7 ms, total: 234 ms
Wall time: 236 ms


In [15]:
# Summary of the pre-sampled fit loaded from the cache.
print(fit3)

Inference for Stan model: anon_model_b9d9d4eb1cc460053ac2b8a9dd68028e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.03    0.11   5.12   -1.8   4.68   8.01  11.21   18.4   2243    1.0
tau        6.63    0.14   5.46   0.26   2.48   5.45   9.28  20.87   1557    1.0
eta[0]     0.39    0.01   0.95  -1.52  -0.24   0.41   1.05   2.14   4000    1.0
eta[1]    -0.02    0.01   0.87  -1.76  -0.59  -0.01   0.55   1.66   4000    1.0
eta[2]    -0.19    0.01   0.93  -2.02  -0.81  -0.18   0.43   1.66   4000    1.0
eta[3]    -0.04    0.01   0.88  -1.76  -0.61  -0.05   0.53   1.71   4000    1.0
eta[4]    -0.37    0.01   0.85   -2.1  -0.94  -0.36   0.18   1.32   4000    1.0
eta[5]    -0.23    0.01   0.89  -1.99  -0.81  -0.23   0.34   1.58   4000    1.0
eta[6]     0.34    0.01   0.88  -1.45   -0.2   0.36   0.94   2.05   4000    1.0
eta[7]     

# Environment

In [16]:
import sys
from datetime import datetime

import cython
import numpy

# Record when the notebook was run and which tool versions produced the
# outputs above.
print("Notebook date:", datetime.now().date(), "\n")
print("python", sys.version)
for package in (numpy, cython, pystan):
    print(package.__name__, package.__version__)

Notebook date: 2018-08-28 

python 3.6.3 (default, Oct  3 2017, 21:45:48) 
[GCC 7.2.0]
numpy 1.14.5
cython 0.28.5
pystan 2.17.1.0
