forked from PGM-Lab/InferPy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_runtime.py
66 lines (49 loc) · 1.65 KB
/
test_runtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pytest
import tensorflow as tf
import inferpy as inf
from inferpy.util import runtime
@pytest.fixture(params=[True, False])
def is_tf_run_default(request):
    """Parametrized fixture that applies a tf-run setting and yields it.

    Each test requesting this fixture is executed once per entry in
    ``params``. The current value is handed to ``inf.set_tf_run`` before the
    test body runs, and the same value is yielded so the test can branch on
    which mode is active.
    """
    tf_run_enabled = request.param
    inf.set_tf_run(tf_run_enabled)
    yield tf_run_enabled
    # NOTE(review): nothing is restored after the yield, so the last value
    # set here stays in effect once the test finishes -- confirm that no
    # other test relies on the library's initial setting.
def test_tf_run_allowed(is_tf_run_default):
    """Check a single ``tf_run_allowed`` function under both tf-run settings.

    The decorated helper builds ``tf.ones(1)`` and returns it directly. With
    tf-run enabled the call is expected to yield the evaluated value
    (comparing equal to ``[1.]``); with it disabled the call is expected to
    yield the ``tf.Tensor`` itself.
    """
    @runtime.tf_run_allowed
    def ones_tensor():
        return tf.ones(1)

    result = ones_tensor()

    if not is_tf_run_default:
        # tf-run disabled: the tensor comes back unevaluated
        assert isinstance(result, tf.Tensor)
        return

    # tf-run enabled: the tensor has been run in a session
    assert result == [1.]
def test_tf_run_allowed_nested(is_tf_run_default):
    """Check nested ``tf_run_allowed`` functions under both tf-run settings.

    ``outer`` simply forwards the result of ``inner``; both are decorated
    with ``runtime.tf_run_allowed``. The value built inside ``inner`` is
    captured in an enclosing variable so the test can inspect it, separately
    from whatever ``outer`` finally returns.
    """
    # holds the object created inside ``inner`` (stays None if never called)
    inner_value = None

    @runtime.tf_run_allowed
    def inner():
        nonlocal inner_value
        inner_value = tf.ones(1)
        return inner_value

    @runtime.tf_run_allowed
    def outer():
        return inner()

    result = outer()

    if is_tf_run_default:
        # tf-run enabled: the outermost call returns the evaluated value
        assert result == [1.]
    else:
        # tf-run disabled: the outermost call returns the tensor itself
        assert isinstance(result, tf.Tensor)

    # whichever mode is active, the object created inside ``inner`` is a tensor
    assert isinstance(inner_value, tf.Tensor)
def test_tf_run_ignored(is_tf_run_default):
    """Check that a ``tf_run_ignored`` function returns a tensor in both modes.

    ``outer`` is decorated with ``runtime.tf_run_ignored`` and delegates to
    ``inner``, which is decorated with ``runtime.tf_run_allowed`` and builds
    ``tf.ones(1)``. The test expects a ``tf.Tensor`` from ``outer`` for either
    value of ``is_tf_run_default``.
    """
    @runtime.tf_run_allowed
    def inner():
        return tf.ones(1)

    @runtime.tf_run_ignored
    def outer():
        return inner()

    # the fixture value is deliberately not consulted: both modes must agree
    result = outer()
    assert isinstance(result, tf.Tensor)