optimizers.py
r"""
It turns out that Keras' design is somewhat crazy\*, and there is no list of
optimizers that you can just import from Keras. So, this module specifies a
list, and a helper function or two for dealing with optimizer parameters.
Unfortunately, this means that we have a list that must be kept in sync with
Keras. Oh well.
\* Have you seen their get_from_module() method? See here:
https://github.com/fchollet/keras/blob/6e42b0e4a77fb171295b541a6ae9a3a4a79f9c87/keras/utils/generic_utils.py#L10.
That method means I could pass in 'clip_norm' as an optimizer, and it would try
to use that function as an optimizer. It also means there is no simple list of
implemented optimizers I can grab.
\* I should also note that Keras is an incredibly useful library that does a lot
of things really well. It just has a few quirks...
"""
import logging
from typing import Union

# pylint: disable=no-name-in-module
from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
from tensorflow.python.training.rmsprop import RMSPropOptimizer
from tensorflow.python.training.adadelta import AdadeltaOptimizer
from tensorflow.python.training.adagrad import AdagradOptimizer
from tensorflow.python.training.adam import AdamOptimizer
# pylint: enable=no-name-in-module

from ..common.params import Params

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

optimizers = {  # pylint: disable=invalid-name
    'sgd': GradientDescentOptimizer,
    'rmsprop': RMSPropOptimizer,
    'adagrad': AdagradOptimizer,
    'adadelta': AdadeltaOptimizer,
    'adam': AdamOptimizer,
}
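
# Extending this registry is a one-line addition. For example (a sketch, not
# wired up here), TensorFlow's MomentumOptimizer, which lives in
# tensorflow.python.training.momentum and takes a required `momentum` argument
# in addition to `learning_rate`, could be registered as:
#     optimizers['momentum'] = MomentumOptimizer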

def optimizer_from_params(params: Union[Params, str]):
    """
    This method converts from a parameter object like we use in our Trainer
    code into an optimizer object suitable for use with Keras. The simplest
    case is a string that shows up in `optimizers` above: if `params` is just
    one of those strings, we look up the corresponding class and instantiate
    it with its default arguments. If not, we assume `params` is a
    Dict[str, Any] with a "type" key, whose value must be one of the strings
    above. We take the rest of the parameters and pass them to the optimizer's
    constructor.
    """
    if isinstance(params, str):
        optimizer = params
        params = {}
    else:
        optimizer = params.pop_choice("type", optimizers.keys())
    return optimizers[optimizer](**params)
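
# A minimal usage sketch (assuming, as the code above does, that Params wraps a
# plain dict and that its pop_choice() removes the "type" key after validating
# it against the given choices):
#
#     optimizer_from_params('adam')
#     # -> AdamOptimizer() with its default arguments
#
#     optimizer_from_params(Params({'type': 'sgd', 'learning_rate': 0.01}))
#     # -> GradientDescentOptimizer(learning_rate=0.01)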