Commit 7d0355c: Fix V2 API

xuwei06 committed May 26, 2017
1 parent da83d28

Showing 11 changed files with 278 additions and 623 deletions.
1 change: 1 addition & 0 deletions paddle/parameter/Parameter.h
@@ -324,6 +324,7 @@ class Parameter {
   std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
 
 public:
+  void setSharedCount(int cnt) { sharedCount_ = cnt; }
   int getSharedCount() { return sharedCount_; }
 
   bool isSparse() { return config_.is_sparse(); }
31 changes: 17 additions & 14 deletions python/paddle/trainer/config_parser.py
@@ -3371,7 +3371,7 @@ def Import(config_file, local_args={}):
     return Import
 
 
-settings = dict(
+default_settings = dict(
     batch_size=None,
     mini_batch_size=None,
     algorithm='async_sgd',
@@ -3404,6 +3404,8 @@ def Import(config_file, local_args={}):
     adam_beta2=0.999,
     adam_epsilon=1e-8, )
 
+settings = copy.deepcopy(default_settings)
+
 settings_deprecated = dict(usage_ratio=1., )
 
 trainer_settings = dict(
@@ -3544,23 +3546,32 @@ def update_g_config():
     return g_config
 
 
-def parse_config(trainer_config, config_arg_str):
+def begin_parse(config_arg_str=''):
     '''
-    @param trainer_config: can be a string of config file name or a function name
-    with config logic
     @param config_arg_str: a string of the form var1=val1,var2=val2. It will be
     passed to config script as a dictionary CONFIG_ARGS
     '''
     init_config_environment()
     for hook in _parse_config_hooks:
         hook()
 
-    config_args = {}
-
     logger.findCaller = find_caller
     logger.fatal = my_fatal
 
+    g_config.model_config.type = "nn"
+
+    global g_current_submodel, g_root_submodel
+    g_root_submodel = g_config.model_config.sub_models.add()
+    g_root_submodel.name = 'root'
+    g_root_submodel.is_recurrent_layer_group = False
+    g_current_submodel = g_root_submodel
+
+
+def parse_config(trainer_config, config_arg_str):
+    begin_parse(config_arg_str)
+
+    config_args = {}
+
     if config_arg_str:
         config_args = dict([f.split('=') for f in config_arg_str.split(',')])
@@ -3573,14 +3584,6 @@ def parse_config(trainer_config, config_arg_str):
         extension_module = importlib(extension_module_name)
         g_extended_config_funcs = extension_module.get_config_funcs(g_config)
 
-    g_config.model_config.type = 'nn'
-
-    global g_current_submodel, g_root_submodel
-    g_root_submodel = g_config.model_config.sub_models.add()
-    g_root_submodel.name = 'root'
-    g_root_submodel.is_recurrent_layer_group = False
-    g_current_submodel = g_root_submodel
-
     if hasattr(trainer_config, '__call__'):
         trainer_config.func_globals.update(
             make_config_environment("", config_args))
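
The upshot of this refactor is that all of the parser's global state (the root submodel, logger hooks, and the model type) is now initialized by begin_parse(), which parse_config() calls first; callers such as the V2 API can therefore reset the parser without running a full trainer-config parse. A minimal usage sketch, where my_conf is a hypothetical config function:

    import paddle.trainer.config_parser as config_parser

    def my_conf():
        # hypothetical config logic; a real config defines layers,
        # settings(), data sources, and so on
        pass

    # parse_config() now calls begin_parse() internally, so the global
    # state is fresh on every parse.
    config = config_parser.parse_config(my_conf, 'batch_size=32')

    # The V2 API can also reset the global state directly:
    config_parser.begin_parse()
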
22 changes: 18 additions & 4 deletions python/paddle/trainer_config_helpers/config_parser_utils.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import paddle.trainer.config_parser as config_parser
+from paddle.proto.TrainerConfig_pb2 import OptimizationConfig
+
 '''
-This file is a wrapper of formal config_parser. The main idea of this file is to 
+This file is a wrapper of formal config_parser. The main idea of this file is to
 separate different config logic into different functions, such as network configuration
 and optimizer configuration.
 '''
 
 __all__ = [
-    "parse_trainer_config", "parse_network_config", "parse_optimizer_config"
+    "parse_trainer_config", "parse_network_config", "parse_optimizer_config",
+    "reset_parser"
 ]
 
@@ -34,5 +38,15 @@ def parse_network_config(network_conf, config_arg_str=''):
 
 
 def parse_optimizer_config(optimizer_conf, config_arg_str=''):
-    config = config_parser.parse_config(optimizer_conf, config_arg_str)
-    return config.opt_config
+    config_parser.settings = copy.deepcopy(config_parser.default_settings)
+    optimizer_conf()
+    opt_config = OptimizationConfig()
+    for k, v in config_parser.settings.iteritems():
+        if v is None:
+            continue
+        opt_config.__setattr__(k, v)
+    return opt_config
+
+
+def reset_parser():
+    config_parser.begin_parse()
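
parse_optimizer_config() no longer runs a full config parse: it resets settings to default_settings, calls the optimizer function for its side effects on config_parser.settings, and copies every non-None entry into an OptimizationConfig proto. A usage sketch (the settings() arguments below are illustrative):

    from paddle.trainer_config_helpers.config_parser_utils import (
        parse_optimizer_config, reset_parser)
    from paddle.trainer_config_helpers.optimizers import settings, AdamOptimizer

    def my_optimizer():
        # settings() mutates config_parser.settings as a side effect
        settings(
            batch_size=128,
            learning_rate=1e-3,
            learning_method=AdamOptimizer())

    opt_config = parse_optimizer_config(my_optimizer)  # an OptimizationConfig proto
    reset_parser()  # clear global parser state before defining the next model
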
6 changes: 6 additions & 0 deletions python/paddle/trainer_config_helpers/layers.py
@@ -285,6 +285,7 @@ def __init__(self,
         assert size is not None
         assert LayerType.is_layer_type(layer_type)
         self.name = name
+        self.full_name = MakeLayerNameInSubmodel(name)
         self.layer_type = layer_type
         if parents is not None and type(parents) != list:
             parents = [parents]
@@ -3489,6 +3490,11 @@ def map_in_links(x):
 
     RecurrentLayerGroupEnd(name=name)
 
+    for layer_out in layer_outs:
+        # The previous full_name is the name inside the rnn group;
+        # we need a full_name outside the rnn group.
+        layer_out.full_name = MakeLayerNameInSubmodel(layer_out.name)
+
     if len(layer_outs) == 1:
         return layer_outs[0]
     else:
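
The new full_name attribute records a layer's name qualified by the submodel it was created in (via MakeLayerNameInSubmodel). Inside a recurrent group that qualified name points into the rnn submodel, so the loop above re-derives full_name for each output once the group has ended and the enclosing submodel is current again. A rough sketch of where the distinction shows up, with illustrative layer names:

    from paddle.trainer_config_helpers import *

    def step(x):
        # inside the group, x.full_name is qualified by the rnn submodel
        return fc_layer(input=x, size=32)

    data = data_layer(name='word', size=1000)
    out = recurrent_group(
        name='rnn', step=step, input=embedding_layer(input=data, size=32))

    # After RecurrentLayerGroupEnd, out.full_name has been re-derived in
    # the enclosing submodel, which is the name the V2 API must look up.
    print out.name, out.full_name
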
14 changes: 1 addition & 13 deletions python/paddle/v2/evaluator.py
@@ -25,21 +25,9 @@ def convert_to_new_name(nm):
 
 for __ev_name__ in filter(lambda x: x.endswith('_evaluator'), evs.__all__):
     __ev__ = getattr(evs, __ev_name__)
-    if hasattr(__ev__, 'argspec'):
-        argspec = __ev__.argspec
-    else:
-        argspec = inspect.getargspec(__ev__)
-    parent_names = filter(lambda x: x in ['input', 'label', 'weight'],
-                          argspec.args)
-    v2_ev = __convert_to_v2__(
-        __ev_name__,
-        parent_names=parent_names,
-        is_default_name='name' in argspec.args,
-        attach_parent=True)
-
     __new_name__ = convert_to_new_name(__ev_name__)
 
-    globals()[__new_name__] = v2_ev
+    globals()[__new_name__] = __ev__
     globals()[__new_name__].__name__ = __new_name__
     __all__.append(__new_name__)
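
With this change, paddle.v2.evaluator simply re-exports each trainer_config_helpers evaluator under its shortened name (convert_to_new_name strips the '_evaluator' suffix) instead of wrapping it with __convert_to_v2__. A sketch of the resulting call style, with hypothetical layer variables prediction and label:

    import paddle.v2 as paddle

    # classification_error is classification_error_evaluator re-exported
    # under its shortened name, bound directly to the underlying function
    paddle.evaluator.classification_error(input=prediction, label=label)
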