Skip to content

Commit

Permalink
Merge pull request #6000 from ReyhaneAskari/new_destroy_handler
Browse files Browse the repository at this point in the history
New destroy handler
  • Loading branch information
abergeron committed Aug 14, 2017
2 parents 26d4705 + f03092a commit e082176
Show file tree
Hide file tree
Showing 20 changed files with 147 additions and 38 deletions.
9 changes: 9 additions & 0 deletions doc/faq.txt
Expand Up @@ -43,6 +43,8 @@ CPUs. In fact, Theano asks g++ what are the equivalent flags it uses, and re-use
them directly.


.. _faster-theano-function-compilation:

Faster Theano Function Compilation
----------------------------------

Expand All @@ -67,6 +69,13 @@ compilation but it will also use more memory because
resulting in a trade off between speed of compilation and memory
usage.

Alternatively, if the graph is big, using the flag ``cycle_detection=fast``
will speedup the computations by removing some of the inplace
optimizations. This would allow theano to skip a time consuming cycle
detection algorithm. If the graph is big enough,we suggest that you use
this flag instead of ``optimizer_excluding=inplace``. It will result in a
computation time that is in between fast compile and fast run.

Theano flag `reoptimize_unpickled_function` controls if an unpickled
theano function should reoptimize its graph or not. Theano users can
use the standard python pickle tools to save a compiled theano
Expand Down
3 changes: 2 additions & 1 deletion doc/tutorial/modes.txt
Expand Up @@ -225,7 +225,8 @@ stabilize "+++++" "++" Only applies stability opts
================= ============ ============== ==================================================

For a detailed list of the specific optimizations applied for each of these
optimizers, see :ref:`optimizations`. Also, see :ref:`unsafe_optimization`.
optimizers, see :ref:`optimizations`. Also, see :ref:`unsafe_optimization` and
:ref:`faster-theano-function-compilation` for other trade-off.


.. _using_debugmode:
Expand Down
39 changes: 20 additions & 19 deletions theano/compile/debugmode.py
Expand Up @@ -2273,25 +2273,26 @@ def __init__(self, inputs, outputs, mode,
"of", len(li), "events was stable.",
file=sys.stderr)
self.fgraph = fgraph
destroy_handler_added = False
for feature in fgraph._features:
if isinstance(feature, gof.DestroyHandler):
destroy_handler_added = True
break
if not destroy_handler_added:
fgraph.attach_feature(gof.DestroyHandler())
for o in fgraph.outputs:
try:
with change_flags(compute_test_value=config.compute_test_value_opt):
fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
raise Exception("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?" % o)

except gof.InconsistencyError:
# This output is already impossible to destroy.
# No guard necessary
pass
if theano.config.cycle_detection == 'regular':
destroy_handler_added = False
for feature in fgraph._features:
if isinstance(feature, gof.DestroyHandler):
destroy_handler_added = True
break
if not destroy_handler_added:
fgraph.attach_feature(gof.DestroyHandler())
for o in fgraph.outputs:
try:
with change_flags(compute_test_value=config.compute_test_value_opt):
fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
raise Exception("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?" % o)

except gof.InconsistencyError:
# This output is already impossible to destroy.
# No guard necessary
pass

linker = _Linker(self)

Expand Down
11 changes: 8 additions & 3 deletions theano/compile/function_module.py
Expand Up @@ -132,6 +132,11 @@ def __init__(self, protected):
self.protected = list(protected)

def validate(self, fgraph):
if config.cycle_detection == 'fast' and hasattr(fgraph, 'has_destroyers'):
if fgraph.has_destroyers(self.protected):
raise gof.InconsistencyError("Trying to destroy a protected"
"Variable.")
return True
if not hasattr(fgraph, 'destroyers'):
return True
for r in self.protected + list(fgraph.outputs):
Expand Down Expand Up @@ -190,7 +195,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input)))))
fgraph.has_destroyers([input])))))

# If named nodes are replaced, keep the name
for feature in std_fgraph.features:
Expand Down Expand Up @@ -1111,7 +1116,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):

# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs)
has_destroyers = hasattr(fgraph, 'get_destroyers_of')
has_destroyers_attr = hasattr(fgraph, 'has_destroyers')

for i in xrange(len(fgraph.outputs)):
views_of_output_i = set()
Expand Down Expand Up @@ -1142,7 +1147,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# being updated
if input_j in updated_fgraph_inputs:
continue
if input_j in views_of_output_i and not (has_destroyers and fgraph.get_destroyers_of(input_j)):
if input_j in views_of_output_i and not (has_destroyers_attr and fgraph.has_destroyers([input_j])):
# We don't put deep_copy_op if the input and the
# output have borrow==True
if input_j in fgraph.inputs:
Expand Down
2 changes: 1 addition & 1 deletion theano/configdefaults.py
Expand Up @@ -1575,7 +1575,7 @@ def filter_vm_lazy(val):

"The interaction of which one give the lower peak memory usage is"
"complicated and not predictable, so if you are close to the peak"
"memory usage, triyng both could give you a small gain. ",
"memory usage, triyng both could give you a small gain.",
EnumStr('regular', 'fast'),
in_c_key=False)

Expand Down
56 changes: 47 additions & 9 deletions theano/gof/destroyhandler.py
Expand Up @@ -250,7 +250,7 @@ def fast_inplace_check(inputs):

inputs = [i for i in inputs if
not isinstance(i, graph.Constant) and
not fgraph.destroyers(i) and
not fgraph.has_destroyers([i]) and
i not in protected_inputs]
return inputs

Expand Down Expand Up @@ -297,7 +297,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
<unknown>
"""
pickle_rm_attr = ["destroyers"]
pickle_rm_attr = ["destroyers", "has_destroyers"]

def __init__(self, do_imports_on_attach=True, algo=None):
self.fgraph = None
Expand Down Expand Up @@ -394,6 +394,41 @@ def get_destroyers_of(r):
return []
fgraph.destroyers = get_destroyers_of

def has_destroyers(protected_list):
if self.algo != 'fast':
droot, _, root_destroyer = self.refresh_droot_impact()
for protected_var in protected_list:
try:
root_destroyer[droot[protected_var]]
return True
except KeyError:
pass
return False

def recursive_destroys_finder(protected_var):
# protected_var is the idx'th input of app.
for (app, idx) in protected_var.clients:
if app == 'output':
continue
destroy_maps = getattr(app.op, 'destroy_map', {}).values()
# If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True
for var_idx in getattr(app.op, 'view_map', {}).keys():
if idx in app.op.view_map[var_idx]:
# We need to recursivly check the destroy_map of all the
# outputs that we have a view_map on.
if recursive_destroys_finder(app.outputs[var_idx]):
return True
return False

for protected_var in protected_list:
if recursive_destroys_finder(protected_var):
return True
return False

fgraph.has_destroyers = has_destroyers

def refresh_droot_impact(self):
"""
Makes sure self.droot, self.impact, and self.root_destroyer are up to
Expand All @@ -416,6 +451,7 @@ def on_detach(self, fgraph):
del self.stale_droot
assert self.fgraph.destroyer_handler is self
delattr(self.fgraph, 'destroyers')
delattr(self.fgraph, 'has_destroyers')
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None

Expand Down Expand Up @@ -452,11 +488,11 @@ def fast_destroy(self, app, reason):
if len(v) > 0:
self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has view_map. " + str(reason))
elif d:
d = d.get(inp_idx2, [])
if len(d) > 0:
self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map. " + str(reason))
elif d:
d = d.get(inp_idx2, [])
if len(d) > 0:
self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map. " + str(reason))

# These 2 assertions are commented since this function is called so many times
# but they should be true.
Expand All @@ -474,13 +510,15 @@ def on_import(self, fgraph, app, reason):
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)

# If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', None):
dmap = getattr(app.op, 'destroy_map', None)
vmap = getattr(app.op, 'view_map', {})
if dmap:
self.destroyers.add(app)
if self.algo == 'fast':
self.fast_destroy(app, reason)

# add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
for o_idx, i_idx_list in iteritems(vmap):
if len(i_idx_list) > 1:
raise NotImplementedError(
'destroying this output invalidates multiple inputs',
Expand Down
12 changes: 12 additions & 0 deletions theano/gof/tests/test_destroyhandler.py
Expand Up @@ -11,6 +11,7 @@
from theano.gof import destroyhandler
from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import ReplaceValidate
from theano.tests.unittest_tools import assertFailure_fast

from theano.configparser import change_flags

Expand Down Expand Up @@ -169,6 +170,7 @@ def test_misc():
######################


@assertFailure_fast
def test_aliased_inputs_replacement():
x, y, z = inputs()
tv = transpose_view(x)
Expand Down Expand Up @@ -200,6 +202,7 @@ def test_indestructible():
consistent(g)


@assertFailure_fast
def test_usage_loop_through_views_2():
x, y, z = inputs()
e0 = transpose_view(transpose_view(sigmoid(x)))
Expand All @@ -210,6 +213,7 @@ def test_usage_loop_through_views_2():
inconsistent(g) # we cut off the path to the sigmoid


@assertFailure_fast
def test_destroyers_loop():
# AddInPlace(x, y) and AddInPlace(y, x) should not coexist
x, y, z = inputs()
Expand Down Expand Up @@ -259,6 +263,7 @@ def test_aliased_inputs2():
inconsistent(g)


@assertFailure_fast
def test_aliased_inputs_tolerate():
x, y, z = inputs()
e = add_in_place_2(x, x)
Expand All @@ -273,13 +278,15 @@ def test_aliased_inputs_tolerate2():
inconsistent(g)


@assertFailure_fast
def test_same_aliased_inputs_ignored():
x, y, z = inputs()
e = add_in_place_3(x, x)
g = Env([x], [e], False)
consistent(g)


@assertFailure_fast
def test_different_aliased_inputs_ignored():
x, y, z = inputs()
e = add_in_place_3(x, transpose_view(x))
Expand Down Expand Up @@ -314,6 +321,7 @@ def test_indirect():
inconsistent(g)


@assertFailure_fast
def test_indirect_2():
x, y, z = inputs()
e0 = transpose_view(x)
Expand All @@ -325,6 +333,7 @@ def test_indirect_2():
consistent(g)


@assertFailure_fast
def test_long_destroyers_loop():
x, y, z = inputs()
e = dot(dot(add_in_place(x, y),
Expand Down Expand Up @@ -366,6 +375,7 @@ def test_multi_destroyers():
pass


@assertFailure_fast
def test_multi_destroyers_through_views():
x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x))
Expand Down Expand Up @@ -408,6 +418,7 @@ def test_usage_loop_through_views():
consistent(g)


@assertFailure_fast
def test_usage_loop_insert_views():
x, y, z = inputs()
e = dot(add_in_place(x, add(y, z)),
Expand Down Expand Up @@ -442,6 +453,7 @@ def test_value_repl_2():
consistent(g)


@assertFailure_fast
def test_multiple_inplace():
# this tests issue #5223
# there were some problems with Ops that have more than
Expand Down
2 changes: 2 additions & 0 deletions theano/gpuarray/tests/test_dnn.py
Expand Up @@ -1754,6 +1754,7 @@ def test_without_dnn_batchnorm_train_without_running_averages():
f_abstract(X, Scale, Bias, Dy)


@utt.assertFailure_fast
def test_dnn_batchnorm_train_inplace():
# test inplace_running_mean and inplace_running_var
if not dnn.dnn_available(test_ctx_name):
Expand Down Expand Up @@ -1876,6 +1877,7 @@ def test_batchnorm_inference():
utt.assert_allclose(outputs_abstract[5], outputs_ref[5], rtol=2e-3, atol=4e-5) # dvar


@utt.assertFailure_fast
def test_batchnorm_inference_inplace():
# test inplace
if not dnn.dnn_available(test_ctx_name):
Expand Down
6 changes: 6 additions & 0 deletions theano/gpuarray/tests/test_linalg.py
Expand Up @@ -175,6 +175,7 @@ def invalid_input_func():
GpuCholesky(lower=True, inplace=False)(A)
self.assertRaises(AssertionError, invalid_input_func)

@utt.assertFailure_fast
def test_diag_chol(self):
# Diagonal matrix input Cholesky test.
for lower in [True, False]:
Expand All @@ -183,6 +184,7 @@ def test_diag_chol(self):
A_val = np.diag(np.random.uniform(size=5).astype("float32") + 1)
self.compare_gpu_cholesky_to_np(A_val, lower=lower, inplace=inplace)

@utt.assertFailure_fast
def test_dense_chol_lower(self):
# Dense matrix input lower-triangular Cholesky test.
for lower in [True, False]:
Expand Down Expand Up @@ -243,6 +245,7 @@ def test_gpu_matrix_inverse(self):
A_val_inv = fn(A_val)
utt.assert_allclose(np.eye(N), np.dot(A_val_inv, A_val), atol=1e-2)

@utt.assertFailure_fast
def test_gpu_matrix_inverse_inplace(self):
N = 1000
test_rng = np.random.RandomState(seed=1)
Expand All @@ -258,6 +261,7 @@ def test_gpu_matrix_inverse_inplace(self):
fn()
utt.assert_allclose(np.eye(N), np.dot(A_val_gpu.get_value(), A_val_copy), atol=5e-3)

@utt.assertFailure_fast
def test_gpu_matrix_inverse_inplace_opt(self):
A = theano.tensor.fmatrix("A")
fn = theano.function([A], matrix_inverse(A), mode=mode_with_gpu)
Expand Down Expand Up @@ -360,6 +364,7 @@ def test_gpu_cholesky_opt(self):
assert any([isinstance(node.op, GpuMagmaCholesky)
for node in fn.maker.fgraph.toposort()])

@utt.assertFailure_fast
def test_gpu_cholesky_inplace(self):
A = self.rand_symmetric(1000)
A_gpu = gpuarray_shared_constructor(A)
Expand All @@ -375,6 +380,7 @@ def test_gpu_cholesky_inplace(self):
L = A_gpu.get_value()
utt.assert_allclose(np.dot(L, L.T), A_copy, atol=1e-3)

@utt.assertFailure_fast
def test_gpu_cholesky_inplace_opt(self):
A = theano.tensor.fmatrix("A")
fn = theano.function([A], GpuMagmaCholesky()(A), mode=mode_with_gpu)
Expand Down

0 comments on commit e082176

Please sign in to comment.