From 2368ed3e7f14801785280038094ac99773e876a6 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Sat, 3 Jun 2023 17:07:04 -0500
Subject: [PATCH] Fix bad initial value shape assumptions in save_mem_new_scan

---
 aesara/scan/rewriting.py     | 170 +++++++++++++++++++++--------------
 tests/scan/test_basic.py     |  16 ----
 tests/scan/test_rewriting.py | 146 ++++++++++++++++++++++++++++++
 3 files changed, 247 insertions(+), 85 deletions(-)

diff --git a/aesara/scan/rewriting.py b/aesara/scan/rewriting.py
index 72b30b2673..3fbffbf882 100644
--- a/aesara/scan/rewriting.py
+++ b/aesara/scan/rewriting.py
@@ -4,7 +4,7 @@
 import dataclasses
 from itertools import chain
 from sys import maxsize
-from typing import Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast

 import numpy as np

@@ -36,7 +36,6 @@
 from aesara.scan.utils import (
     ScanArgs,
     compress_outs,
-    expand_empty,
     reconstruct_graph,
     safe_new,
     scan_can_remove_outs,
@@ -48,18 +47,22 @@
 from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
 from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
 from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
-from aesara.tensor.shape import shape
+from aesara.tensor.shape import shape, shape_tuple
 from aesara.tensor.subtensor import (
     IncSubtensor,
     Subtensor,
     get_canonical_form_slice,
     get_idx_list,
     get_slice_elements,
+    indices_from_subtensor,
     set_subtensor,
 )
 from aesara.tensor.var import TensorConstant, get_unique_value

+if TYPE_CHECKING:
+    from aesara.tensor.var import TensorVariable
+
 list_opt_slice = [
     local_abs_merge,
     local_mul_switch_sink,
@@ -1103,6 +1106,72 @@ def sanitize(x):
         return at.as_tensor_variable(x)


+def reshape_output_storage(
+    out_storage: "TensorVariable",
+    steps_needed: "TensorVariable",
+    tap_spread: int,
+) -> "TensorVariable":
+    """Resize the first dimension of ``storage`` in ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
+
+    This is used by `save_mem_new_scan` to reduce the amount of storage
+    (pre)allocated for `Scan` output arrays (i.e. ``storage`` is assumed to be
+    an `AllocEmpty` output).
+
+    Parameters
+    ----------
+    out_storage
+        This corresponds to a graph with the form
+        ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
+    steps_needed
+        The number of output steps that actually need to be stored.
+    tap_spread
+        The spread of the relevant tap. This will generally be the length of
+        ``initial_tap_vals``, but sometimes not (e.g. because the initial
+        values broadcast across the indices/slice).
+
+    Returns
+    -------
+    A graph like
+    ``set_subtensor(new_storage[:tap_spread], initial_tap_vals)``,
+    where ``new_storage`` is an `AllocEmpty` whose first dimension has length
+    ``steps_needed`` (padded by ``2 * tap_spread`` when
+    ``steps_needed < tap_spread``).  If ``out_storage`` does not have the
+    expected ``set_subtensor`` form, a slice of length
+    ``maximum(steps_needed, tap_spread)`` of ``out_storage`` is returned
+    instead.
+
+    """
+    out_storage_node = out_storage.owner
+    if (
+        out_storage_node
+        and isinstance(out_storage_node.op, IncSubtensor)
+        and out_storage_node.op.set_instead_of_inc
+        and len(out_storage_node.op.idx_list) == 1
+        and isinstance(out_storage_node.op.idx_list[0], slice)
+    ):
+        # The "fill" value of the `IncSubtensor` across the
+        # slice. This should generally consist of the initial
+        # values.
+        initial_tap_vals = out_storage_node.inputs[1]
+
+        storage_slice = indices_from_subtensor(
+            out_storage_node.inputs[2:], out_storage_node.op.idx_list
+        )
+        inner_storage_var = out_storage_node.inputs[0]
+
+        # Why this size exactly? (N.B. This is what the original Theano logic ultimately did.)
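+        # A plausible reading (not verified against the original Theano
+        # sources): when `steps_needed < tap_spread`, the rotating buffer
+        # still has to hold the `tap_spread` initial tap values in addition
+        # to the computed steps, and the extra `tap_spread` of padding simply
+        # reproduces the allocation the old Theano rewrite ended up making.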
+ max_storage_size = at.switch( + at.lt(steps_needed, tap_spread), steps_needed + 2 * tap_spread, steps_needed + ) + new_inner_storage_var = at.empty( + ( + max_storage_size, + *shape_tuple(inner_storage_var)[1:], + ), + dtype=initial_tap_vals.dtype, + ) + res = at.set_subtensor(new_inner_storage_var[storage_slice], initial_tap_vals) + else: + max_storage_size = maximum(steps_needed, tap_spread) + res = out_storage[:max_storage_size] + + return cast("TensorVariable", res) + + @node_rewriter([Scan]) def save_mem_new_scan(fgraph, node): r"""Graph optimizer that reduces scan memory consumption. @@ -1398,13 +1467,16 @@ def save_mem_new_scan(fgraph, node): # by the inner function .. ) replaced_outs = [] offset = 1 + op_info.n_seqs + op_info.n_mit_mot - for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]): + for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]): i = idx + op_info.n_mit_mot - if not (isinstance(_val, int) and _val <= 0 and i not in required): - if idx + op_info.n_mit_mot in required: - val = 1 - else: - val = _val + if not ( + isinstance(steps_needed, int) + and steps_needed <= 0 + and i not in required + ): + if i in required: + steps_needed = 1 + # If the memory for this output has been pre-allocated # before going into the scan op (by an alloc node) if idx < op_info.n_mit_sot + op_info.n_sit_sot: @@ -1413,38 +1485,18 @@ def save_mem_new_scan(fgraph, node): # a) the input is a set_subtensor, in that case we # can replace the initial tensor by a slice, # b) it is not, and we simply take a slice of it. - # TODO: commit change below with Razvan - if ( - nw_inputs[offset + idx].owner - and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor) - and isinstance( - nw_inputs[offset + idx].owner.op.idx_list[0], slice - ) - ): - assert isinstance( - nw_inputs[offset + idx].owner.op, IncSubtensor - ) - _nw_input = nw_inputs[offset + idx].owner.inputs[1] - cval = at.as_tensor_variable(val) - initl = at.as_tensor_variable(init_l[i]) - tmp_idx = at.switch(cval < initl, cval + initl, cval - initl) - nw_input = expand_empty(_nw_input, tmp_idx) - else: - tmp = at.as_tensor_variable(val) - initl = at.as_tensor_variable(init_l[i]) - tmp = maximum(tmp, initl) - nw_input = nw_inputs[offset + idx][:tmp] + out_storage = nw_inputs[offset + idx] + tap_spread = init_l[i] + nw_input = reshape_output_storage( + out_storage, steps_needed, tap_spread + ) nw_inputs[offset + idx] = nw_input - replaced_outs.append(op_info.n_mit_mot + idx) - odx = op_info.n_mit_mot + idx + replaced_outs.append(i) old_outputs += [ ( - odx, - [ - x[0].outputs[0] - for x in fgraph.clients[node.outputs[odx]] - ], + i, + [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]], ) ] # If there is no memory pre-allocated for this output @@ -1457,48 +1509,28 @@ def save_mem_new_scan(fgraph, node): + op_info.n_shared_outs ) if nw_inputs[pos] == node.inputs[0]: - nw_inputs[pos] = val - odx = op_info.n_mit_mot + idx - replaced_outs.append(odx) + nw_inputs[pos] = steps_needed + replaced_outs.append(i) old_outputs += [ ( - odx, - [ - x[0].outputs[0] - for x in fgraph.clients[node.outputs[odx]] - ], + i, + [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]], ) ] # 3.4. 
Recompute inputs for everything else based on the new # number of steps if global_nsteps is not None: - for idx, val in enumerate(store_steps[op_info.n_mit_mot :]): - if val == 0: - # val == 0 means that we want to keep all intermediate + for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]): + if steps_needed == 0: + # steps_needed == 0 means that we want to keep all intermediate # results for that state, including the initial values. if idx < op_info.n_mit_sot + op_info.n_sit_sot: in_idx = offset + idx - # Number of steps in the initial state - initl = init_l[op_info.n_mit_mot + idx] - - # If the initial buffer has the form - # inc_subtensor(zeros(...)[...], _nw_input) - # we want to make the zeros tensor as small as - # possible (nw_steps + initl), and call - # inc_subtensor on that instead. - # Otherwise, simply take 0:(nw_steps+initl). - if ( - nw_inputs[in_idx].owner - and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor) - and isinstance( - nw_inputs[in_idx].owner.op.idx_list[0], slice - ) - ): - _nw_input = nw_inputs[in_idx].owner.inputs[1] - nw_input = expand_empty(_nw_input, nw_steps) - nw_inputs[in_idx] = nw_input - else: - nw_input = nw_inputs[in_idx][: (initl + nw_steps)] + out_storage = nw_inputs[in_idx] + tap_spread = init_l[op_info.n_mit_mot + idx] + nw_input = reshape_output_storage( + out_storage, steps_needed, tap_spread + ) elif ( idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index c0371d4338..9936b34c9b 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -2950,22 +2950,6 @@ def rec_fn(*args): utt.assert_allclose(outs[2], v_w + 3) utt.assert_allclose(sh.get_value(), v_w + 4) - def test_seq_tap_bug_jeremiah(self): - inp = np.arange(10).reshape(-1, 1).astype(config.floatX) - exp_out = np.zeros((10, 1)).astype(config.floatX) - exp_out[4:] = inp[:-4] - - def onestep(x, x_tm4): - return x, x_tm4 - - seq = matrix() - initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) - outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None] - results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) - - f = function([seq], results[1]) - assert np.all(exp_out == f(inp)) - def test_shared_borrow(self): """ This tests two things. The first is a bug occurring when scan wrongly diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index 20c1844fdd..ee85a7588a 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -1392,6 +1392,152 @@ def f_pow2(x_tm1): rng = np.random.default_rng(utt.fetch_seed()) my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3])) + def test_init_vals_no_broadcast_case(self): + """The output storage is a `set_subtensor` with initial tap values that don't need to be broadcasted.""" + m, n = (10, 2) + A = at.zeros((m, n), dtype=config.floatX) + + taps = (-1,) + initial = at.arange(len(taps) * n, dtype=config.floatX).reshape((len(taps), n)) + + def fn(i0, im1): + return at.switch(im1, 0, i0) + + out, _ = aesara.scan( + fn=fn, + sequences=A, + outputs_info={"taps": taps, "initial": initial}, + ) + + unopt_output_storage = out.owner.inputs[0].owner.inputs[-1] + + # The unoptimized output storage is `len(taps) + n_steps` (i.e. 
11) + assert unopt_output_storage.shape[0].eval() == 11 + + f = function([], out, mode=self.mode) + + out_val = f() + exp_out = function( + [], + out, + mode=get_default_mode().excluding( + "scan", "scan_save_mem", "save_mem_new_scan" + ), + )() + assert np.array_equal(exp_out, out_val) + + scan_node = next( + n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan) + ) + opt_output_storage = scan_node.inputs[1] + + # The optimized output storage only needs `n_steps` (i.e. 10) of + # storage, as long as `len(n_steps) > len(taps)`. + assert function([], opt_output_storage.shape[0], accept_inplace=True)() <= 10 + + def test_init_vals_broadcast_case(self): + """The output storage is a `set_subtensor` with initial tap values that must be broadcasted.""" + m, n = (10, 2) + A = at.zeros((m, n), dtype=config.floatX) + + taps = (-1,) + # Because we're using `zeros` here, `save_mem_new_scan` can get an + # output storage array as follows: + # |IncSubtensor{Set;:int64:} [id C] (outer_in_mit_sot-0) + # |AllocEmpty{dtype='float64'} [id D] + # | |TensorConstant{11} [id E] + # | |TensorConstant{2} [id F] + # |TensorConstant{0.0} [id G] + # |ScalarConstant{1} [id H] + # The `Subtensor` "set" value is a scalar (i.e. id G), so it can't be naively + # reused by `save_mem_new_scan`; it needs the implicit broadcast dimensions. + initial = at.zeros((len(taps), n), dtype=config.floatX) + + def fn(i0, im1): + return at.switch(im1, 0, i0) + + out, _ = aesara.scan( + fn=fn, + sequences=A, + outputs_info={"taps": taps, "initial": initial}, + ) + + unopt_output_storage = out.owner.inputs[0].owner.inputs[-1] + + # The unoptimized output storage is `tap_spread + n_steps` (i.e. 11) + assert unopt_output_storage.shape[0].eval() == 11 + + f = function(inputs=[], outputs=out, mode=self.mode) + + out_val = f() + exp_out = function( + [], + out, + mode=get_default_mode().excluding( + "scan", "scan_save_mem", "save_mem_new_scan" + ), + )() + assert np.array_equal(exp_out, out_val) + + scan_node = next( + n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan) + ) + opt_output_storage = scan_node.inputs[1] + + # The optimized output storage only needs `n_steps` (i.e. 10) of + # storage, as long as `len(n_steps) > len(taps)`. + assert function([], opt_output_storage.shape[0], accept_inplace=True)() <= 10 + + @pytest.mark.parametrize( + "n_steps, tap_spread, opt_storage_shapes", + [ + # TODO FIXME: The first storage array shape seems too large. + (3, 4, (9, 3)), + (10, 4, (9, 10)), + ], + ) + def test_n_steps_taps(self, n_steps, tap_spread, opt_storage_shapes): + inp = np.arange(n_steps).reshape(-1, 1).astype(config.floatX) + + def onestep(x, x_tm4): + return x, x_tm4 + + seq = at.as_tensor(inp) + initial_value = at.as_tensor( + np.arange(-tap_spread, 0, dtype=config.floatX).reshape((tap_spread, 1)) + ) + + results, _ = scan( + fn=onestep, + sequences=seq, + outputs_info=[{"initial": initial_value, "taps": [-tap_spread]}, None], + ) + + # TODO: Alternate the tested output. 
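+        # (Presumably this means also exercising `results[0]` across the
+        # parametrized cases; for now only the lagged output `results[1]` is
+        # compared against the unoptimized result.)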
+ f = function([], results[1], mode=self.mode) + out = f() + + exp_out = function( + [], + results[1], + mode=get_default_mode().excluding( + "scan", "scan_save_mem", "save_mem_new_scan" + ), + )() + assert np.array_equal(exp_out, out) + + scan_node = next( + n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan) + ) + + storage_shape_1 = function( + [], scan_node.inputs[-2].shape[0], accept_inplace=True + )() + assert storage_shape_1 <= opt_storage_shapes[0] + + storage_shape_2 = function([], scan_node.inputs[-1], accept_inplace=True)() + assert storage_shape_2 <= opt_storage_shapes[1] + def test_inner_replace_dot(): """