Skip to content

Commit

Permalink
Merge pull request #2340 from emailweixu/boot_bias_layer
Browse files Browse the repository at this point in the history
Fix handling of boot_bias_parameter for recurrent_group in v2 API
  • Loading branch information
emailweixu committed Jun 1, 2017
2 parents c562e57 + 02a509f commit 303d266
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions python/paddle/v2/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def dfs_travel(layer_name):
return layer_names


- def __get_used_parameters__(layer_names):
+ def __get_used_parameters__(layer_names, sub_models):
parameter_names = set()
for name in layer_names:
l = cp.g_layer_map[name]
Expand All @@ -161,6 +161,12 @@ def __get_used_parameters__(layer_names):
parameter_names.add(inp.input_parameter_name)
if l.bias_parameter_name:
parameter_names.add(l.bias_parameter_name)

+ for sub_model in sub_models:
+ for mem in sub_model.memories:
+ if mem.HasField("boot_bias_parameter_name"):
+ parameter_names.add(mem.boot_bias_parameter_name)

return parameter_names


Expand Down Expand Up @@ -236,7 +242,6 @@ def parse_network(output_layers, extra_layers=None):
layer_names = __get_used_layers__(output_layers + extra_layers)
submodel_names = __get_used_submodels__(layer_names)
submodel_names.add('root')
- parameter_names = __get_used_parameters__(layer_names)
evaluator_names = __get_used_evaluators__(layer_names)
input_layer_names = set()
output_layer_names = set()
Expand All @@ -251,10 +256,6 @@ def parse_network(output_layers, extra_layers=None):
model_config.input_layer_names.append(l.name)
input_layer_names.add(l.name)

- for p in cp.g_config.model_config.parameters:
- if p.name in parameter_names:
- model_config.parameters.extend([p])

for layer in output_layers:
model_config.output_layer_names.append(layer.full_name)
output_layer_names.add(layer.full_name)
Expand All @@ -269,6 +270,13 @@ def parse_network(output_layers, extra_layers=None):
output_layer_names, evaluator_names)
model_config.sub_models.extend([s])

+ parameter_names = __get_used_parameters__(layer_names,
+ model_config.sub_models)
+
+ for p in cp.g_config.model_config.parameters:
+ if p.name in parameter_names:
+ model_config.parameters.extend([p])

return model_config


Expand Down

0 comments on commit 303d266

Please sign in to comment.