Fix bad initial values shape assumptions in save_mem_new_scan rewrite #1501

Merged
170 changes: 101 additions & 69 deletions aesara/scan/rewriting.py
@@ -4,7 +4,7 @@
import dataclasses
from itertools import chain
from sys import maxsize
from typing import Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast

import numpy as np

@@ -36,7 +36,6 @@
from aesara.scan.utils import (
ScanArgs,
compress_outs,
expand_empty,
reconstruct_graph,
safe_new,
scan_can_remove_outs,
@@ -48,18 +47,22 @@
from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from aesara.tensor.shape import shape
from aesara.tensor.shape import shape, shape_tuple
from aesara.tensor.subtensor import (
IncSubtensor,
Subtensor,
get_canonical_form_slice,
get_idx_list,
get_slice_elements,
indices_from_subtensor,
set_subtensor,
)
from aesara.tensor.var import TensorConstant, get_unique_value


if TYPE_CHECKING:
from aesara.tensor.var import TensorVariable

list_opt_slice = [
local_abs_merge,
local_mul_switch_sink,
@@ -1103,6 +1106,72 @@ def sanitize(x):
return at.as_tensor_variable(x)


def reshape_output_storage(
out_storage: "TensorVariable",
steps_needed: "TensorVariable",
tap_spread: int,
) -> "TensorVariable":
"""Resize the first dimension of ``storage`` in ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.

This is used by `save_mem_new_scan` to reduce the amount of storage
(pre)allocated for `Scan` output arrays (i.e. ``storage`` is assumed to be
an `AllocEmpty` output).

Parameters
----------
    out_storage
        This corresponds to a graph with the form
        ``set_subtensor(storage[:tap_spread], initial_tap_vals)``.
    steps_needed
        The number of steps of the output that need to be stored.
    tap_spread
        The spread of the relevant tap. This will generally be the length of
        ``initial_tap_vals``, but sometimes not (e.g. because the initial
        values broadcast across the indices/slice).

Returns
-------
    A graph like
    ``set_subtensor(new_storage[:tap_spread], initial_tap_vals)``,
    where ``new_storage`` is an `AllocEmpty` with a first
    dimension of length ``maximum(steps_needed, tap_spread)``.

"""
out_storage_node = out_storage.owner
if (
out_storage_node
and isinstance(out_storage_node.op, IncSubtensor)
and out_storage_node.op.set_instead_of_inc
and len(out_storage_node.op.idx_list) == 1
and isinstance(out_storage_node.op.idx_list[0], slice)
):
# The "fill" value of the `IncSubtensor` across the
# slice. This should generally consist of the initial
# values.
initial_tap_vals = out_storage_node.inputs[1]

storage_slice = indices_from_subtensor(
out_storage_node.inputs[2:], out_storage_node.op.idx_list
)
inner_storage_var = out_storage_node.inputs[0]

# Why this size exactly? (N.B. This is what the original Theano logic ultimately did.)
max_storage_size = at.switch(
at.lt(steps_needed, tap_spread), steps_needed + 2 * tap_spread, steps_needed
)
new_inner_storage_var = at.empty(
(
max_storage_size,
*shape_tuple(inner_storage_var)[1:],
),
dtype=initial_tap_vals.dtype,
)
res = at.set_subtensor(new_inner_storage_var[storage_slice], initial_tap_vals)
else:
max_storage_size = maximum(steps_needed, tap_spread)
res = out_storage[:max_storage_size]

return cast("TensorVariable", res)


@node_rewriter([Scan])
def save_mem_new_scan(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption.
@@ -1398,13 +1467,16 @@ def save_mem_new_scan(fgraph, node):
# by the inner function .. )
replaced_outs = []
offset = 1 + op_info.n_seqs + op_info.n_mit_mot
for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
i = idx + op_info.n_mit_mot
if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op_info.n_mit_mot in required:
val = 1
else:
val = _val
if not (
isinstance(steps_needed, int)
and steps_needed <= 0
and i not in required
):
if i in required:
steps_needed = 1

# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
@@ -1413,38 +1485,18 @@
# a) the input is a set_subtensor, in that case we
# can replace the initial tensor by a slice,
# b) it is not, and we simply take a slice of it.
# TODO: commit change below with Razvan
if (
nw_inputs[offset + idx].owner
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice
)
):
assert isinstance(
nw_inputs[offset + idx].owner.op, IncSubtensor
)
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = at.as_tensor_variable(val)
initl = at.as_tensor_variable(init_l[i])
tmp_idx = at.switch(cval < initl, cval + initl, cval - initl)
nw_input = expand_empty(_nw_input, tmp_idx)
else:
tmp = at.as_tensor_variable(val)
initl = at.as_tensor_variable(init_l[i])
tmp = maximum(tmp, initl)
nw_input = nw_inputs[offset + idx][:tmp]
out_storage = nw_inputs[offset + idx]
tap_spread = init_l[i]
nw_input = reshape_output_storage(
out_storage, steps_needed, tap_spread
)

nw_inputs[offset + idx] = nw_input
replaced_outs.append(op_info.n_mit_mot + idx)
odx = op_info.n_mit_mot + idx
replaced_outs.append(i)
old_outputs += [
(
odx,
[
x[0].outputs[0]
for x in fgraph.clients[node.outputs[odx]]
],
i,
[x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
)
]
# If there is no memory pre-allocated for this output
@@ -1457,48 +1509,28 @@
+ op_info.n_shared_outs
)
if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val
odx = op_info.n_mit_mot + idx
replaced_outs.append(odx)
nw_inputs[pos] = steps_needed
replaced_outs.append(i)
old_outputs += [
(
odx,
[
x[0].outputs[0]
for x in fgraph.clients[node.outputs[odx]]
],
i,
[x[0].outputs[0] for x in fgraph.clients[node.outputs[i]]],
)
]
# 3.4. Recompute inputs for everything else based on the new
# number of steps
if global_nsteps is not None:
for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
if val == 0:
# val == 0 means that we want to keep all intermediate
for idx, steps_needed in enumerate(store_steps[op_info.n_mit_mot :]):
if steps_needed == 0:
# steps_needed == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
in_idx = offset + idx
# Number of steps in the initial state
initl = init_l[op_info.n_mit_mot + idx]

# If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input)
# we want to make the zeros tensor as small as
# possible (nw_steps + initl), and call
# inc_subtensor on that instead.
# Otherwise, simply take 0:(nw_steps+initl).
if (
nw_inputs[in_idx].owner
and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
and isinstance(
nw_inputs[in_idx].owner.op.idx_list[0], slice
)
):
_nw_input = nw_inputs[in_idx].owner.inputs[1]
nw_input = expand_empty(_nw_input, nw_steps)
nw_inputs[in_idx] = nw_input
else:
nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
out_storage = nw_inputs[in_idx]
tap_spread = init_l[op_info.n_mit_mot + idx]
nw_input = reshape_output_storage(
out_storage, steps_needed, tap_spread
)

elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
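Not part of the diff: a minimal usage sketch of the new helper, assuming it is imported from `aesara.scan.rewriting` as added above and that the storage buffer has the `set_subtensor(empty(...)[:tap_spread], init_vals)` form the rewrite looks for.

```python
import numpy as np

import aesara.tensor as at
from aesara.scan.rewriting import reshape_output_storage

# Pre-allocated output storage as `save_mem_new_scan` sees it: an `AllocEmpty`
# buffer whose first `tap_spread` rows hold the initial tap values.
tap_spread = 2
init_vals = at.as_tensor(np.zeros((tap_spread, 3), dtype="float64"))
out_storage = at.set_subtensor(
    at.empty((10, 3), dtype="float64")[:tap_spread], init_vals
)

# Rebuild the buffer so its first dimension only covers the steps we need,
# while keeping the initial tap values in the leading rows.
steps_needed = at.lscalar("steps_needed")
resized = reshape_output_storage(out_storage, steps_needed, tap_spread)
```

When `out_storage` does not have that `IncSubtensor` form, the helper falls back to slicing the existing buffer to `maximum(steps_needed, tap_spread)`.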
40 changes: 24 additions & 16 deletions tests/scan/test_basic.py
@@ -2950,22 +2950,6 @@ def rec_fn(*args):
utt.assert_allclose(outs[2], v_w + 3)
utt.assert_allclose(sh.get_value(), v_w + 4)

def test_seq_tap_bug_jeremiah(self):
inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
exp_out = np.zeros((10, 1)).astype(config.floatX)
exp_out[4:] = inp[:-4]

def onestep(x, x_tm4):
return x, x_tm4

seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None]
results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)

f = function([seq], results[1])
assert np.all(exp_out == f(inp))

def test_shared_borrow(self):
"""
This tests two things. The first is a bug occurring when scan wrongly
@@ -4031,3 +4015,27 @@ def fn(n):
res = f_cvm()

assert np.array_equal(res, np.array([3, 1, 0]))


@pytest.mark.xfail(reason="Need to fix overly strict tensor type checking")
def test_bad_broadcast_check():
inp = np.arange(10).reshape(-1, 1).astype(config.floatX)

def onestep(x, x_tm4):
return x, x_tm4

# This will have a broadcastable last dimension
seq = at.as_tensor(inp)

# This won't, so it will fail
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))

    outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)

exp_out = np.zeros((10, 1)).astype(config.floatX)
exp_out[4:] = inp[:-4]

f = function([], results[1])
out = f()
assert np.array_equal(exp_out, out)
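
For context (not part of the diff), a small sketch of the type-level mismatch the xfail test above exercises, assuming Aesara's default broadcastable-flag inference for constants and shared variables:

```python
import numpy as np

import aesara.tensor as at
from aesara import config, shared

inp = np.arange(10).reshape(-1, 1).astype(config.floatX)

# A constant built from a (10, 1) array gets a broadcastable last
# dimension ...
seq = at.as_tensor(inp)
print(seq.broadcastable)  # (False, True)

# ... while the shared initial state keeps a non-broadcastable last dimension,
# so the inner-graph output and the pre-allocated storage disagree on types.
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
print(initial_value.broadcastable)  # (False, False)
```

Once the overly strict type check named in the xfail reason is relaxed, the scan in the test should produce the shifted sequence stored in `exp_out`.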