-
Notifications
You must be signed in to change notification settings - Fork 31
/
test_selection.py
60 lines (53 loc) · 2.41 KB
/
test_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import pytest
import psyneulink.core.components.functions.selectionfunctions as Functions
import psyneulink.core.globals.keywords as kw
import psyneulink.core.llvm as pnlvm
# The expected outputs below are hard-coded from the draws made under this
# seed (e.g. 0.92732552 == test_var[8]), so the seed and the order of the two
# rand() calls must not change.
np.random.seed(0)
SIZE = 10
test_var = np.random.rand(SIZE) * 2.0 - 1.0
test_prob = np.random.rand(SIZE)


def _one_hot(index, value, size=SIZE):
    """Return a list of ``size`` zeros with ``value`` placed at ``index``."""
    vector = [0.0] * size
    vector[index] = value
    return vector


# Each entry pairs an explicit pytest id with its (func, variable, params,
# expected) case; explicit ids are used because pytest's auto-generated ids
# for these parameters are unreadable.
_cases = [
    ("OneHot MAX_VAL",
     Functions.OneHot, test_var, {'mode': kw.MAX_VAL}, _one_hot(8, 0.92732552)),
    ("OneHot MAX_ABS_VAL",
     Functions.OneHot, test_var, {'mode': kw.MAX_ABS_VAL}, _one_hot(8, 0.92732552)),
    ("OneHot MAX_INDICATOR",
     Functions.OneHot, test_var, {'mode': kw.MAX_INDICATOR}, _one_hot(8, 1.0)),
    ("OneHot MAX_ABS_INDICATOR",
     Functions.OneHot, test_var, {'mode': kw.MAX_ABS_INDICATOR}, _one_hot(8, 1.0)),
    ("OneHot MIN_VAL",
     Functions.OneHot, test_var, {'mode': kw.MIN_VAL}, _one_hot(9, -0.23311696)),
    ("OneHot MIN_ABS_VAL",
     Functions.OneHot, test_var, {'mode': kw.MIN_ABS_VAL}, _one_hot(3, 0.08976637)),
    ("OneHot MIN_INDICATOR",
     Functions.OneHot, test_var, {'mode': kw.MIN_INDICATOR}, _one_hot(9, 1.0)),
    ("OneHot MIN_ABS_INDICATOR",
     Functions.OneHot, test_var, {'mode': kw.MIN_ABS_INDICATOR}, _one_hot(3, 1.0)),
    ("OneHot PROB",
     Functions.OneHot, [test_var, test_prob], {'mode': kw.PROB}, _one_hot(0, 0.09762701)),
    ("OneHot PROB_INDICATOR",
     Functions.OneHot, [test_var, test_prob], {'mode': kw.PROB_INDICATOR}, _one_hot(0, 1.0)),
]

# Split the table into the two parallel sequences handed to parametrize.
test_data = [case[1:] for case in _cases]
names = [case[0] for case in _cases]

GROUP_PREFIX = "SelectionFunction "
@pytest.mark.function
# NOTE(review): this marker looks copied from the integrator tests; a
# selection-specific marker may be intended — confirm against registered markers.
@pytest.mark.integrator_function
@pytest.mark.benchmark
@pytest.mark.parametrize("func, variable, params, expected", test_data, ids=names)
def test_basic(func, variable, params, expected, benchmark, func_mode):
    """Run a selection function in the requested execution mode and check its output.

    Args:
        func: selection function class to instantiate (e.g. ``OneHot``).
        variable: input passed both as ``default_variable`` and at call time.
        params: extra constructor keyword arguments; must contain ``'mode'``,
            which is also used in the benchmark group name.
        expected: output the call should match within ``np.allclose`` tolerance.
        benchmark: benchmark fixture; its ``group`` is set and it is only
            invoked when ``benchmark.enabled`` is true.
        func_mode: one of ``'Python'``, ``'LLVM'`` or ``'PTX'`` (presumably
            supplied by a conftest fixture — confirm).
    """
    f = func(default_variable=variable, **params)
    # Keep component name and mode visually separate in the group label.
    benchmark.group = GROUP_PREFIX + func.componentName + " " + params['mode']
    if func_mode == 'Python':
        EX = f
    elif func_mode == 'LLVM':
        EX = pnlvm.execution.FuncExecution(f).execute
    elif func_mode == 'PTX':
        EX = pnlvm.execution.FuncExecution(f).cuda_execute
    else:
        # Without this branch an unknown mode left EX unbound and surfaced as
        # an UnboundLocalError below instead of a clear failure.
        pytest.fail("Unknown func_mode: {!r}".format(func_mode))
    # The first call's result is not checked — presumably a warm-up run
    # (e.g. to trigger compilation); TODO confirm.
    EX(variable)
    res = EX(variable)
    assert np.allclose(res, expected)
    if benchmark.enabled:
        benchmark(EX, variable)