Skip to content

Commit

Permalink
change of names and fixes for recursive_destroys_finder
Browse files Browse the repository at this point in the history
  • Loading branch information
ReyhaneAskari committed Jun 2, 2017
1 parent 590f2f6 commit 10573b9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
11 changes: 6 additions & 5 deletions theano/compile/function_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,12 @@ def __init__(self, protected):
self.protected = list(protected)

def validate(self, fgraph):
if config.cycle_detection == 'fast' and hasattr(fgraph, 'fast_destroyers_check'):
if fgraph.fast_destroyers_check(self.protected):
if config.cycle_detection == 'fast' and hasattr(fgraph, 'has_destroyers'):
if fgraph.has_destroyers(self.protected):
raise gof.InconsistencyError("Trying to destroy a protected"
"Variable.")

else:
return True
if not hasattr(fgraph, 'destroyers'):
return True
for r in self.protected + list(fgraph.outputs):
Expand Down Expand Up @@ -1091,7 +1092,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):

# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs)
has_destroyers = hasattr(fgraph, 'get_destroyers_of')
has_get_destroyers = hasattr(fgraph, 'get_destroyers_of')

for i in xrange(len(fgraph.outputs)):
views_of_output_i = set()
Expand Down Expand Up @@ -1122,7 +1123,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# being updated
if input_j in updated_fgraph_inputs:
continue
if input_j in views_of_output_i and not (has_destroyers and fgraph.get_destroyers_of(input_j)):
if input_j in views_of_output_i and not (has_get_destroyers and fgraph.get_destroyers_of(input_j)):
# We don't put deep_copy_op if the input and the
# output have borrow==True
if input_j in fgraph.inputs:
Expand Down
25 changes: 12 additions & 13 deletions theano/gof/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
<unknown>
"""
pickle_rm_attr = ["destroyers", "fast_destroyers_check"]
pickle_rm_attr = ["destroyers", "has_destroyers"]

def __init__(self, do_imports_on_attach=True, algo=None):
self.fgraph = None
Expand Down Expand Up @@ -395,24 +395,23 @@ def get_destroyers_of(r):
fgraph.destroyers = get_destroyers_of

def recursive_destroys_finder(clients_list):
for client in clients_list:
# client is a tuple (I don't know if its size is always one)
for item in client:
if item.op.destroy_map:
for (app, idx) in clients_list:
if app == 'output':
continue
destroy_maps = getattr(app.op, 'destroy_map', {}).values()
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True
for var in getattr(app.op, 'view_map', {}).keys():
if recursive_destroys_finder(app.outputs[var].clients):
return True
if len(item.outputs) == 0:
return False
for output in item.outputs:
if recursive_destroys_finder(output.clients):
return True
return False

def fast_destroyers_check(protected_list):
def has_destroyers(protected_list):
for protected_var in protected_list:
if recursive_destroys_finder(protected_var.clients):
return True

fgraph.fast_destroyers_check = fast_destroyers_check
fgraph.has_destroyers = has_destroyers

def refresh_droot_impact(self):
"""
Expand All @@ -436,7 +435,7 @@ def on_detach(self, fgraph):
del self.stale_droot
assert self.fgraph.destroyer_handler is self
delattr(self.fgraph, 'destroyers')
delattr(self.fgraph, 'fast_destroyers_check')
delattr(self.fgraph, 'has_destroyers')
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None

Expand Down

0 comments on commit 10573b9

Please sign in to comment.