Fix bad initial value shape assumptions in save_mem_new_scan
brandonwillard committed Jun 3, 2023
1 parent 6cb7a38 commit 2368ed3
Showing 3 changed files with 247 additions and 85 deletions.
170 changes: 101 additions & 69 deletions aesara/scan/rewriting.py
@@ -4,7 +4,7 @@
import dataclasses
from itertools import chain
from sys import maxsize
from typing import Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast

import numpy as np

@@ -36,7 +36,6 @@
from aesara.scan.utils import (
    ScanArgs,
    compress_outs,
    expand_empty,
    reconstruct_graph,
    safe_new,
    scan_can_remove_outs,
@@ -48,18 +47,22 @@
from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from aesara.tensor.shape import shape
from aesara.tensor.shape import shape, shape_tuple
from aesara.tensor.subtensor import (
    IncSubtensor,
    Subtensor,
    get_canonical_form_slice,
    get_idx_list,
    get_slice_elements,
    indices_from_subtensor,
    set_subtensor,
)
from aesara.tensor.var import TensorConstant, get_unique_value


if TYPE_CHECKING:
    from aesara.tensor.var import TensorVariable

list_opt_slice = [
    local_abs_merge,
    local_mul_switch_sink,
@@ -1103,6 +1106,72 @@ def sanitize(x):
        return at.as_tensor_variable(x)


def reshape_output_storage(
    out_storage: "TensorVariable",
    steps_needed: "TensorVariable",
    tap_spread: int,
) -> "TensorVariable":
    """Resize the first dimension of ``storage`` in ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.

    This is used by `save_mem_new_scan` to reduce the amount of storage
    (pre)allocated for `Scan` output arrays (i.e. ``storage`` is assumed to be
    an `AllocEmpty` output).

    Parameters
    ----------
    out_storage
        This corresponds to a graph with the form
        ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
    steps_needed
        The number of steps for which storage needs to be available.
    tap_spread
        The spread of the relevant tap. This will generally be the length of
        ``initial_tap_vals``, but sometimes not (e.g. because the initial
        values broadcast across the indices/slice).

    Returns
    -------
    Return a graph like
    ``set_subtensor(new_storage[:tap_spread], initial_tap_vals)``,
    where ``new_storage`` is an `AllocEmpty` with a first
    dimension having a length of at least ``maximum(steps_needed, tap_spread)``.

    """
    out_storage_node = out_storage.owner
    if (
        out_storage_node
        and isinstance(out_storage_node.op, IncSubtensor)
        and out_storage_node.op.set_instead_of_inc
        and len(out_storage_node.op.idx_list) == 1
        and isinstance(out_storage_node.op.idx_list[0], slice)
    ):
        # The "fill" value of the `IncSubtensor` across the
        # slice. This should generally consist of the initial
        # values.
        initial_tap_vals = out_storage_node.inputs[1]

        storage_slice = indices_from_subtensor(
            out_storage_node.inputs[2:], out_storage_node.op.idx_list
        )
        inner_storage_var = out_storage_node.inputs[0]

        # Why this size exactly? (N.B. This is what the original Theano logic ultimately did.)
        max_storage_size = at.switch(
            at.lt(steps_needed, tap_spread), steps_needed + 2 * tap_spread, steps_needed
        )
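        # For example, steps_needed = 2 with tap_spread = 4 allocates
        # 2 + 2 * 4 = 10 rows, whereas steps_needed = 10 with tap_spread = 4
        # allocates exactly 10 rows.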
        new_inner_storage_var = at.empty(
            (
                max_storage_size,
                *shape_tuple(inner_storage_var)[1:],
            ),
            dtype=initial_tap_vals.dtype,
        )
        res = at.set_subtensor(new_inner_storage_var[storage_slice], initial_tap_vals)
    else:
        max_storage_size = maximum(steps_needed, tap_spread)
        res = out_storage[:max_storage_size]

    return cast("TensorVariable", res)


@node_rewriter([Scan])
def save_mem_new_scan(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption.
@@ -1398,13 +1467,16 @@ def save_mem_new_scan(fgraph, node):
        # by the inner function .. )
        replaced_outs = []
        offset = 1 + op_info.n_seqs + op_info.n_mit_mot
        for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
        for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
            i = idx + op_info.n_mit_mot
            if not (isinstance(_val, int) and _val <= 0 and i not in required):
                if idx + op_info.n_mit_mot in required:
                    val = 1
                else:
                    val = _val
            if not (
                isinstance(steps_needed, int)
                and steps_needed <= 0
                and i not in required
            ):
                if i in required:
                    steps_needed = 1

                # If the memory for this output has been pre-allocated
                # before going into the scan op (by an alloc node)
                if idx < op_info.n_mit_sot + op_info.n_sit_sot:
@@ -1413,38 +1485,18 @@
                    # a) the input is a set_subtensor, in that case we
                    # can replace the initial tensor by a slice,
                    # b) it is not, and we simply take a slice of it.
                    # TODO: commit change below with Razvan
                    if (
                        nw_inputs[offset + idx].owner
                        and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
                        and isinstance(
                            nw_inputs[offset + idx].owner.op.idx_list[0], slice
                        )
                    ):
                        assert isinstance(
                            nw_inputs[offset + idx].owner.op, IncSubtensor
                        )
                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                        cval = at.as_tensor_variable(val)
                        initl = at.as_tensor_variable(init_l[i])
                        tmp_idx = at.switch(cval < initl, cval + initl, cval - initl)
                        nw_input = expand_empty(_nw_input, tmp_idx)
                    else:
                        tmp = at.as_tensor_variable(val)
                        initl = at.as_tensor_variable(init_l[i])
                        tmp = maximum(tmp, initl)
                        nw_input = nw_inputs[offset + idx][:tmp]
                    out_storage = nw_inputs[offset + idx]
                    tap_spread = init_l[i]
                    nw_input = reshape_output_storage(
                        out_storage, steps_needed, tap_spread
                    )

                    nw_inputs[offset + idx] = nw_input
                    replaced_outs.append(op_info.n_mit_mot + idx)
                    odx = op_info.n_mit_mot + idx
                    replaced_outs.append(i)
                    old_outputs += [
                        (
                            odx,
                            [
                                x[0].outputs[0]
                                for x in fgraph.clients[node.outputs[odx]]
                            ],
                            i,
                            [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
                        )
                    ]
                # If there is no memory pre-allocated for this output
@@ -1457,48 +1509,28 @@
                        + op_info.n_shared_outs
                    )
                    if nw_inputs[pos] == node.inputs[0]:
                        nw_inputs[pos] = val
                        odx = op_info.n_mit_mot + idx
                        replaced_outs.append(odx)
                        nw_inputs[pos] = steps_needed
                        replaced_outs.append(i)
                        old_outputs += [
                            (
                                odx,
                                [
                                    x[0].outputs[0]
                                    for x in fgraph.clients[node.outputs[odx]]
                                ],
                                i,
                                [x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
                            )
                        ]
        # 3.4. Recompute inputs for everything else based on the new
        # number of steps
        if global_nsteps is not None:
            for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
                if val == 0:
                    # val == 0 means that we want to keep all intermediate
            for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
                if steps_needed == 0:
                    # steps_needed == 0 means that we want to keep all intermediate
                    # results for that state, including the initial values.
                    if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                        in_idx = offset + idx
                        # Number of steps in the initial state
                        initl = init_l[op_info.n_mit_mot + idx]

                        # If the initial buffer has the form
                        # inc_subtensor(zeros(...)[...], _nw_input)
                        # we want to make the zeros tensor as small as
                        # possible (nw_steps + initl), and call
                        # inc_subtensor on that instead.
                        # Otherwise, simply take 0:(nw_steps+initl).
                        if (
                            nw_inputs[in_idx].owner
                            and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
                            and isinstance(
                                nw_inputs[in_idx].owner.op.idx_list[0], slice
                            )
                        ):
                            _nw_input = nw_inputs[in_idx].owner.inputs[1]
                            nw_input = expand_empty(_nw_input, nw_steps)
                            nw_inputs[in_idx] = nw_input
                        else:
                            nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
                        out_storage = nw_inputs[in_idx]
                        tap_spread = init_l[op_info.n_mit_mot + idx]
                        nw_input = reshape_output_storage(
                            out_storage, steps_needed, tap_spread
                        )

                    elif (
                        idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
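
To make the new helper concrete, here is a minimal usage sketch (not part of the commit; it assumes a build of aesara that includes this change, so that reshape_output_storage is importable from aesara.scan.rewriting):

import aesara.tensor as at

from aesara.scan.rewriting import reshape_output_storage

# Case (a): the storage graph has the form
# ``set_subtensor(empty[:tap_spread], initial_tap_vals)``, so the helper
# rebuilds it around a smaller `AllocEmpty`.
init_vals = at.zeros((4, 1))
storage = at.set_subtensor(at.empty((100, 1))[:4], init_vals)
steps_needed = at.as_tensor_variable(2)
resized = reshape_output_storage(storage, steps_needed, tap_spread=4)
# The new buffer has ``switch(lt(2, 4), 2 + 2 * 4, 2) == 10`` rows.

# Case (b): any other storage graph is simply sliced down to
# ``maximum(steps_needed, tap_spread)`` rows.
plain_storage = at.zeros((100, 1))
resized_b = reshape_output_storage(plain_storage, steps_needed, tap_spread=4)
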
16 changes: 0 additions & 16 deletions tests/scan/test_basic.py
@@ -2950,22 +2950,6 @@ def rec_fn(*args):
        utt.assert_allclose(outs[2], v_w + 3)
        utt.assert_allclose(sh.get_value(), v_w + 4)

    def test_seq_tap_bug_jeremiah(self):
        inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
        exp_out = np.zeros((10, 1)).astype(config.floatX)
        exp_out[4:] = inp[:-4]

        def onestep(x, x_tm4):
            return x, x_tm4

        seq = matrix()
        initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
        outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None]
        results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)

        f = function([seq], results[1])
        assert np.all(exp_out == f(inp))

    def test_shared_borrow(self):
        """
        This tests two things. The first is a bug occurring when scan wrongly
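
The diff for the third changed file is not shown above; the deleted test was presumably relocated there. For reference, here is a standalone sketch of the scenario it exercised: a sequence delayed through a ``taps=[-4]`` state, which drives the tap-initialized storage path that reshape_output_storage now handles (illustrative only; assumes a working aesara install):

import numpy as np

from aesara import config, function, shared
from aesara.scan.basic import scan
from aesara.tensor import matrix

inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
exp_out = np.zeros((10, 1)).astype(config.floatX)
exp_out[4:] = inp[:-4]

def onestep(x, x_tm4):
    # Emit the current element and the value from four steps back.
    return x, x_tm4

seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)

# The delayed output equals the input shifted forward by four steps.
f = function([seq], results[1])
assert np.all(exp_out == f(inp))
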
