-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_auto_param.py
28 lines (21 loc) · 912 Bytes
/
test_auto_param.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from unittest import TestCase
from hyperparameter import auto_param, param_scope
class TestAutoParam(TestCase):
    """Tests for ``auto_param`` picking up argument defaults from ``param_scope``."""

    def test_auto_param_func(self):
        """Scope entries keyed ``foo.<arg>`` override that argument's default."""

        @auto_param("foo")
        def foo(a, b=1, c=2.0, d=False, e="str"):
            return a, b, c, d, e

        # Outside any scope the declared defaults apply.
        self.assertEqual(foo(1), (1, 1, 2.0, False, "str"))

        with param_scope(**{"foo.b": 2}):
            self.assertEqual(foo(1), (1, 2, 2.0, False, "str"))

        # A separate scope overrides only ``c``; the earlier ``b`` override
        # ended with its ``with`` block, so ``b`` is back to its default.
        with param_scope(**{"foo.c": 3.0}):
            self.assertEqual(foo(1), (1, 1, 3.0, False, "str"))

        # Once every scope has exited, the declared defaults are restored.
        self.assertEqual(foo(1), (1, 1, 2.0, False, "str"))

    def test_auto_param_func2(self):
        """Values assigned via ``param_scope.foo.<arg>`` apply until the scope exits."""

        @auto_param("foo")
        def foo(a, b=1, c=2.0, d=False, e="str"):
            return a, b, c, d, e

        with param_scope():
            param_scope.foo.b = 2
            self.assertEqual(foo(1), (1, 2, 2.0, False, "str"))
            # Assignments accumulate within the same scope: ``b`` stays 2.
            param_scope.foo.c = 3.0
            self.assertEqual(foo(1), (1, 2, 3.0, False, "str"))

        # Assignments made inside the ``with`` block do not leak out of it.
        self.assertEqual(foo(1), (1, 1, 2.0, False, "str"))