Commit

add config_parser in trainer_config_helpers to separate trainer config
jacquesqiao authored and reyoung committed Dec 22, 2016
1 parent 3a80272 commit 843b63b
Showing 4 changed files with 98 additions and 39 deletions.
28 changes: 23 additions & 5 deletions demo/mnist/api_train.py
@@ -9,11 +9,29 @@
 import py_paddle.swig_paddle as api
 from py_paddle import DataProviderConverter
 import paddle.trainer.PyDataProvider2 as dp
-import paddle.trainer.config_parser
 import numpy as np
 import random
 from mnist_util import read_from_mnist
 
+import paddle.trainer_config_helpers.config_parser as config_parser
+from paddle.trainer_config_helpers import *
+
+
+def optimizer_config():
+    settings(
+        learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)
+
+
+def network_config():
+    imgs = data_layer(name='pixel', size=784)
+    hidden1 = fc_layer(input=imgs, size=200)
+    hidden2 = fc_layer(input=hidden1, size=200)
+    inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
+    cost = classification_cost(
+        input=inference, label=data_layer(
+            name='label', size=10))
+    outputs(cost)
+
 
 def init_parameter(network):
     assert isinstance(network, api.GradientMachine)
@@ -54,20 +72,20 @@ def input_order_converter(generator):
 
 def main():
     api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
-    config = paddle.trainer.config_parser.parse_config(
-        'simple_mnist_network.py', '')
 
     # get enable_types for each optimizer.
     # enable_types = [value, gradient, momentum, etc]
     # For each optimizer(SGD, Adam), GradientMachine should enable different
     # buffers.
-    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
+    opt_config_proto = config_parser.parse_optimizer_config(optimizer_config)
+    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
     _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
     enable_types = _temp_optimizer_.getParameterTypes()
 
     # Create Simple Gradient Machine.
+    model_config = config_parser.parse_network_config(network_config)
     m = api.GradientMachine.createFromConfigProto(
-        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
+        model_config, api.CREATE_MODE_NORMAL, enable_types)
 
     # This type check is not useful. Only enable type hint in IDE.
     # Such as PyCharm
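Note: with the markers stripped, the resulting flow in main() reads as below (a
sketch assembled from the hunks above; both protos now come from Python
callables instead of a parsed config file):

    # optimizer proto from the optimizer_config function
    opt_config_proto = config_parser.parse_optimizer_config(optimizer_config)
    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    # network proto from the network_config function
    model_config = config_parser.parse_network_config(network_config)
    m = api.GradientMachine.createFromConfigProto(
        model_config, api.CREATE_MODE_NORMAL, enable_types)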
70 changes: 36 additions & 34 deletions python/paddle/trainer/config_parser.py
@@ -3416,8 +3416,35 @@ def register_parse_config_hook(f):
     _parse_config_hooks.add(f)
 
 
-def parse_config(config_file, config_arg_str):
+def update_g_config():
+    '''
+    Update g_config after executing the config file or config functions.
+    '''
+    for k, v in settings.iteritems():
+        if v is None:
+            continue
+        g_config.opt_config.__setattr__(k, v)
+
+    for k, v in trainer_settings.iteritems():
+        if v is None:
+            continue
+        g_config.__setattr__(k, v)
+
+    for name in g_config.model_config.input_layer_names:
+        assert name in g_layer_map, \
+            'input name "%s" does not correspond to a layer name' % name
+        assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
+            'The type of input layer "%s" is not "data"' % name
+    for name in g_config.model_config.output_layer_names:
+        assert name in g_layer_map, \
+            'input name "%s" does not correspond to a layer name' % name
+    return g_config
+
+
+def parse_config(trainer_config, config_arg_str):
     '''
+    @param trainer_config: a config file name, or a callable that defines the
+                           config logic
     @param config_arg_str: a string of the form var1=val1,var2=val2. It will be
            passed to config script as a dictionary CONFIG_ARGS
     '''
@@ -3451,45 +3478,20 @@ def parse_config(config_file, config_arg_str):
     g_root_submodel.is_recurrent_layer_group = False
     g_current_submodel = g_root_submodel
 
-    # for paddle on spark, need support non-file config.
-    # you can use parse_config like below:
-    #
-    # from paddle.trainer.config_parser import parse_config
-    # def configs():
-    #    #your paddle config code, which is same as config file.
-    #
-    # config = parse_config(configs, "is_predict=1")
-    # # then you get config proto object.
-    if hasattr(config_file, '__call__'):
-        config_file.func_globals.update(
+    if hasattr(trainer_config, '__call__'):
+        trainer_config.func_globals.update(
             make_config_environment("", config_args))
-        config_file()
+        trainer_config()
     else:
-        execfile(config_file, make_config_environment(config_file, config_args))
-    for k, v in settings.iteritems():
-        if v is None:
-            continue
-        g_config.opt_config.__setattr__(k, v)
-
-    for k, v in trainer_settings.iteritems():
-        if v is None:
-            continue
-        g_config.__setattr__(k, v)
+        execfile(trainer_config,
+                 make_config_environment(trainer_config, config_args))
 
-    for name in g_config.model_config.input_layer_names:
-        assert name in g_layer_map, \
-            'input name "%s" does not correspond to a layer name' % name
-        assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
-            'The type of input layer "%s" is not "data"' % name
-    for name in g_config.model_config.output_layer_names:
-        assert name in g_layer_map, \
-            'input name "%s" does not correspond to a layer name' % name
-    return g_config
+    return update_g_config()
 
 
-def parse_config_and_serialize(config_file, config_arg_str):
+def parse_config_and_serialize(trainer_config, config_arg_str):
     try:
-        config = parse_config(config_file, config_arg_str)
+        config = parse_config(trainer_config, config_arg_str)
         #logger.info(config)
         return config.SerializeToString()
     except:
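Note: parse_config accepts either form of trainer_config; a minimal sketch of
both call styles (my_config and my_config.py are illustrative, and the
is_predict flag echoes the removed paddle-on-spark comment):

    from paddle.trainer.config_parser import parse_config
    from paddle.trainer_config_helpers import *

    def my_config():
        settings(learning_rate=1e-4, learning_method=AdamOptimizer(),
                 batch_size=1000)
        img = data_layer(name='pixel', size=784)
        fc = fc_layer(input=img, size=10, act=SoftmaxActivation())
        outputs(classification_cost(
            input=fc, label=data_layer(name='label', size=10)))

    config = parse_config(my_config, 'is_predict=1')       # config as a callable
    config = parse_config('my_config.py', 'is_predict=1')  # config as a file name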
1 change: 1 addition & 0 deletions python/paddle/trainer_config_helpers/__init__.py
@@ -20,6 +20,7 @@
 from networks import *
 from optimizers import *
 from attrs import *
+from config_parser import *
 
 # This will enable operator overload for LayerOutput
 import math as layer_math
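Note: since config_parser lists parse_trainer_config, parse_network_config and
parse_optimizer_config in its __all__ (see the new file below), this star
import re-exports them at package level:

    from paddle.trainer_config_helpers import parse_network_config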
38 changes: 38 additions & 0 deletions python/paddle/trainer_config_helpers/config_parser.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.trainer.config_parser as config_parser
+'''
+This file is a wrapper around the formal config_parser. The main idea is to
+separate different pieces of config logic into different functions, such as
+network configuration and optimizer configuration.
+'''
+
+__all__ = [
+    "parse_trainer_config", "parse_network_config", "parse_optimizer_config"
+]
+
+
+def parse_trainer_config(trainer_conf, config_arg_str):
+    return config_parser.parse_config(trainer_conf, config_arg_str)
+
+
+def parse_network_config(network_conf):
+    config = config_parser.parse_config(network_conf, '')
+    return config.model_config
+
+
+def parse_optimizer_config(optimizer_conf):
+    config = config_parser.parse_config(optimizer_conf, '')
+    return config.opt_config

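Note: taken together, a minimal sketch of the intended split (optimizer_config
and network_config are the functions from demo/mnist/api_train.py above;
full_config.py is a hypothetical all-in-one config file):

    import paddle.trainer_config_helpers.config_parser as config_parser

    # parse each piece of config logic separately...
    model_config = config_parser.parse_network_config(network_config)
    opt_config_proto = config_parser.parse_optimizer_config(optimizer_config)

    # ...or keep the old all-in-one behavior
    trainer_config = config_parser.parse_trainer_config('full_config.py',
                                                        'is_predict=1')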