Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
xuwei06 committed May 26, 2017
1 parent 7d0355c commit 0cb8a66
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions python/paddle/v2/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
__all__ = ['data', 'parse_network']
__layer_map__ = {}


def __wrap__(f):
def wrapped(*args, **xargs):
out = f(*args, **xargs)
Expand All @@ -53,6 +54,7 @@ def wrapped(*args, **xargs):

return wrapped


def __need_to_keep__(name):
if name in ['StaticInput', 'LayerType', 'layer_support']:
return False
Expand Down Expand Up @@ -99,6 +101,7 @@ def __data_layer__(name, type, **kwargs):
l.data_type = type
return l


data = __wrap__(__data_layer__)

LayerV2 = v1_layers.LayerOutput
Expand All @@ -107,6 +110,7 @@ def __data_layer__(name, type, **kwargs):
def __get_used_layers__(output_layers, extra_layers=None):
layer_names = set()
parents = {}

def add_parent(child, parent):
if child in parents:
parents[child].append(parent)
Expand Down Expand Up @@ -181,28 +185,25 @@ def __get_used_evaluators__(layer_names):
return evaluator_names


def __trim_submodel__(old_submodel,
layer_names,
input_layer_names,
output_layer_names,
evaluator_names):
def __trim_submodel__(old_submodel, layer_names, input_layer_names,
output_layer_names, evaluator_names):

submodel = SubModelConfig()
submodel.name = old_submodel.name
submodel.layer_names.extend(filter(lambda x: x in layer_names,
old_submodel.layer_names))
submodel.input_layer_names.extend(filter(lambda x: x in input_layer_names,
submodel.layer_names))
submodel.output_layer_names.extend(filter(lambda x: x in output_layer_names,
submodel.layer_names))
submodel.evaluator_names.extend(filter(lambda x: x in evaluator_names,
old_submodel.evaluator_names))
submodel.layer_names.extend(
filter(lambda x: x in layer_names, old_submodel.layer_names))
submodel.input_layer_names.extend(
filter(lambda x: x in input_layer_names, submodel.layer_names))
submodel.output_layer_names.extend(
filter(lambda x: x in output_layer_names, submodel.layer_names))
submodel.evaluator_names.extend(
filter(lambda x: x in evaluator_names, old_submodel.evaluator_names))

submodel.is_recurrent_layer_group = old_submodel.is_recurrent_layer_group
submodel.reversed = old_submodel.reversed

submodel.memories.extend(filter(lambda x: x.link_name in layer_names,
old_submodel.memories))
submodel.memories.extend(
filter(lambda x: x.link_name in layer_names, old_submodel.memories))
target_inlinkid = (old_submodel.target_inlinkid
if old_submodel.HasField('target_inlinkid') else -1)
in_links = []
Expand All @@ -213,8 +214,8 @@ def __trim_submodel__(old_submodel,
target_inlinkid = len(in_links) - 1
submodel.in_links.extend(in_links)

submodel.out_links.extend(filter(lambda x: x.link_name in layer_names,
old_submodel.out_links))
submodel.out_links.extend(
filter(lambda x: x.link_name in layer_names, old_submodel.out_links))
if old_submodel.HasField('generator'):
submodel.generator.CopyFrom(old_submodel.generator)

Expand Down Expand Up @@ -264,9 +265,8 @@ def parse_network(output_layers, extra_layers=None):

for s in cp.g_config.model_config.sub_models:
if s.name in submodel_names:
s = __trim_submodel__(
s, layer_names, input_layer_names, output_layer_names,
evaluator_names)
s = __trim_submodel__(s, layer_names, input_layer_names,
output_layer_names, evaluator_names)
model_config.sub_models.extend([s])

return model_config
Expand Down

0 comments on commit 0cb8a66

Please sign in to comment.