diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index 48ba61c47da41..8fa286b5f940f 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -9,11 +9,29 @@ import py_paddle.swig_paddle as api from py_paddle import DataProviderConverter import paddle.trainer.PyDataProvider2 as dp -import paddle.trainer.config_parser import numpy as np import random from mnist_util import read_from_mnist +import paddle.trainer_config_helpers.config_parser as config_parser +from paddle.trainer_config_helpers import * + + +def optimizer_config(): + settings( + learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000) + + +def network_config(): + imgs = data_layer(name='pixel', size=784) + hidden1 = fc_layer(input=imgs, size=200) + hidden2 = fc_layer(input=hidden1, size=200) + inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation()) + cost = classification_cost( + input=inference, label=data_layer( + name='label', size=10)) + outputs(cost) + def init_parameter(network): assert isinstance(network, api.GradientMachine) @@ -54,20 +72,20 @@ def input_order_converter(generator): def main(): api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores - config = paddle.trainer.config_parser.parse_config( - 'simple_mnist_network.py', '') # get enable_types for each optimizer. # enable_types = [value, gradient, momentum, etc] # For each optimizer(SGD, Adam), GradientMachine should enable different # buffers. - opt_config = api.OptimizationConfig.createFromProto(config.opt_config) + opt_config_proto = config_parser.parse_optimizer_config(optimizer_config) + opt_config = api.OptimizationConfig.createFromProto(opt_config_proto) _temp_optimizer_ = api.ParameterOptimizer.create(opt_config) enable_types = _temp_optimizer_.getParameterTypes() # Create Simple Gradient Machine. 
+ model_config = config_parser.parse_network_config(network_config) m = api.GradientMachine.createFromConfigProto( - config.model_config, api.CREATE_MODE_NORMAL, enable_types) + model_config, api.CREATE_MODE_NORMAL, enable_types) # This type check is not useful. Only enable type hint in IDE. # Such as PyCharm diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 2eb7b17a0b40e..674b5ac58b6fe 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3416,8 +3416,35 @@ def register_parse_config_hook(f): _parse_config_hooks.add(f) -def parse_config(config_file, config_arg_str): +def update_g_config(): ''' + Update g_config after execute config_file or config_functions. + ''' + for k, v in settings.iteritems(): + if v is None: + continue + g_config.opt_config.__setattr__(k, v) + + for k, v in trainer_settings.iteritems(): + if v is None: + continue + g_config.__setattr__(k, v) + + for name in g_config.model_config.input_layer_names: + assert name in g_layer_map, \ + 'input name "%s" does not correspond to a layer name' % name + assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \ + 'The type of input layer "%s" is not "data"' % name + for name in g_config.model_config.output_layer_names: + assert name in g_layer_map, \ + 'input name "%s" does not correspond to a layer name' % name + return g_config + + +def parse_config(trainer_config, config_arg_str): + ''' + @param trainer_config: can be a string of config file name or a function name + with config logic @param config_arg_str: a string of the form var1=val1,var2=val2. It will be passed to config script as a dictionary CONFIG_ARGS ''' @@ -3451,45 +3478,20 @@ def parse_config(config_file, config_arg_str): g_root_submodel.is_recurrent_layer_group = False g_current_submodel = g_root_submodel - # for paddle on spark, need support non-file config. 
- # you can use parse_config like below: - # - # from paddle.trainer.config_parser import parse_config - # def configs(): - # #your paddle config code, which is same as config file. - # - # config = parse_config(configs, "is_predict=1") - # # then you get config proto object. - if hasattr(config_file, '__call__'): - config_file.func_globals.update( + if hasattr(trainer_config, '__call__'): + trainer_config.func_globals.update( make_config_environment("", config_args)) - config_file() + trainer_config() else: - execfile(config_file, make_config_environment(config_file, config_args)) - for k, v in settings.iteritems(): - if v is None: - continue - g_config.opt_config.__setattr__(k, v) - - for k, v in trainer_settings.iteritems(): - if v is None: - continue - g_config.__setattr__(k, v) + execfile(trainer_config, + make_config_environment(trainer_config, config_args)) - for name in g_config.model_config.input_layer_names: - assert name in g_layer_map, \ - 'input name "%s" does not correspond to a layer name' % name - assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \ - 'The type of input layer "%s" is not "data"' % name - for name in g_config.model_config.output_layer_names: - assert name in g_layer_map, \ - 'input name "%s" does not correspond to a layer name' % name - return g_config + return update_g_config() -def parse_config_and_serialize(config_file, config_arg_str): +def parse_config_and_serialize(trainer_config, config_arg_str): try: - config = parse_config(config_file, config_arg_str) + config = parse_config(trainer_config, config_arg_str) #logger.info(config) return config.SerializeToString() except: diff --git a/python/paddle/trainer_config_helpers/__init__.py b/python/paddle/trainer_config_helpers/__init__.py index a2335768b92b6..84ed40a036a18 100644 --- a/python/paddle/trainer_config_helpers/__init__.py +++ b/python/paddle/trainer_config_helpers/__init__.py @@ -20,6 +20,7 @@ from networks import * from optimizers import * 
from attrs import * +from config_parser import * # This will enable operator overload for LayerOutput import math as layer_math diff --git a/python/paddle/trainer_config_helpers/config_parser.py b/python/paddle/trainer_config_helpers/config_parser.py new file mode 100644 index 0000000000000..4b91b8d2824cd --- /dev/null +++ b/python/paddle/trainer_config_helpers/config_parser.py @@ -0,0 +1,38 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.trainer.config_parser as config_parser +''' +This file is a wrapper of formal config_parser. The main idea of this file is to +separate different config logic into different functions, such as network configuration + and optimizer configuration. +''' + +__all__ = [ + "parse_trainer_config", "parse_network_config", "parse_optimizer_config" +] + + +def parse_trainer_config(trainer_conf, config_arg_str): + return config_parser.parse_config(trainer_conf, config_arg_str) + + +def parse_network_config(network_conf): + config = config_parser.parse_config(network_conf, '') + return config.model_config + + +def parse_optimizer_config(optimizer_conf): + config = config_parser.parse_config(optimizer_conf, '') + return config.opt_config