forked from grburgess/stan_utility
/
test_compile.py
69 lines (54 loc) · 2.42 KB
/
test_compile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import numpy as np
import tempfile
import joblib
def test_compile_file():
    """Compile ``test.stan`` from disk, sample it, and check the on-disk cache.

    The module-level settings ``stan_utility.cache.path`` and
    ``stan_utility.cache.mem`` are redirected to a fresh temporary directory
    only for the duration of the test and restored afterwards. Otherwise they
    would keep pointing at the temporary directory after it has been deleted,
    affecting any later code in the same process.
    """
    from unittest import mock

    # importing the submodule also binds (and initialises) the package name
    import stan_utility
    import stan_utility.cache

    model_path = os.path.join(os.path.dirname(__file__), 'test.stan')
    with tempfile.TemporaryDirectory() as cachedir:
        print("using cachedir:", cachedir)
        memory = joblib.Memory(cachedir, verbose=False)
        # patch.object restores the previous attribute values on exit;
        # create=True also deletes them again if they did not exist before.
        with mock.patch.object(stan_utility.cache, "path", cachedir, create=True), \
                mock.patch.object(stan_utility.cache, "mem", memory, create=True):
            model = stan_utility.compile_model(model_path)
            data = dict(
                mean=1,
                # NOTE(review): not referenced by name elsewhere here;
                # presumably checks that surplus data entries are tolerated
                unused=np.random.normal(size=(4, 42)),
            )
            stan_utility.sample_model(model, data, chains=2)

            files = os.listdir(stan_utility.cache.get_path())
            assert "joblib" in files, files
            assert any(
                f.startswith("cached-") and f.endswith('.pkl') for f in files
            ), files
            assert len(files) > 1, files

            # after clearing, only the joblib store directory should remain
            stan_utility.cache.clear()
            files = os.listdir(stan_utility.cache.get_path())
            assert files == ["joblib"], files
def test_compile_string():
    """Compile ``test.stan`` from a source string, then sample, plot and flatten.

    The module-level settings ``stan_utility.cache.path`` and
    ``stan_utility.cache.mem`` are redirected to a temporary directory only
    for the duration of the test and restored afterwards. The fit and corner
    outputs are written into a second temporary directory instead of the
    current working directory, so nothing is left behind, even when an
    assertion fails part-way through.
    """
    from unittest import mock

    # importing the submodule also binds (and initialises) the package name
    import stan_utility
    import stan_utility.cache

    model_path = os.path.join(os.path.dirname(__file__), 'test.stan')
    # read via a context manager so the file handle is closed promptly
    with open(model_path, encoding="utf-8") as model_file:
        model_code = model_file.read()

    n_iter = 346
    with tempfile.TemporaryDirectory() as cachedir, \
            tempfile.TemporaryDirectory() as outdir:
        print("using cachedir:", cachedir)
        memory = joblib.Memory(cachedir, verbose=False)
        # patch.object restores the previous attribute values on exit;
        # create=True also deletes them again if they did not exist before.
        with mock.patch.object(stan_utility.cache, "path", cachedir, create=True), \
                mock.patch.object(stan_utility.cache, "mem", memory, create=True):
            model = stan_utility.compile_model_code(model_code, model_name="mytest")
            data = dict(
                mean=1,
                unused=np.random.normal(size=(4, 42)),
            )

            # Output file names are outprefix + suffix ("fit.hdf5" and
            # "_corner.pdf"), so a path prefix keeps them inside outdir.
            outprefix = os.path.join(outdir, "mytest_fit")
            samples = stan_utility.sample_model(
                model, data, outprefix=outprefix, chains=2, iter=n_iter)
            assert os.path.exists(outprefix + "fit.hdf5")

            stan_utility.plot_corner(samples, outprefix=outprefix)
            assert os.path.exists(outprefix + "_corner.pdf")

            flat_samples = stan_utility.get_flat_posterior(samples)
            assert set(flat_samples.keys()) == {"x", "y"}, flat_samples.keys()
            # NOTE(review): n_iter draws in total presumably comes from the
            # default warmup of iter // 2 over 2 chains — confirm
            assert flat_samples['x'].shape == (n_iter,), flat_samples['x'].shape
            assert flat_samples['y'].shape == (n_iter, 10), flat_samples['y'].shape
if __name__ == '__main__':
    # Allow running the tests directly without pytest (string-based test first).
    for run_test in (test_compile_string, test_compile_file):
        run_test()