Replace uses of Rebroadcast by SpecifyShape #915
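The diff below replaces the remaining uses of the old `Rebroadcast` op: the JAX and Numba dispatchers and the `ifelse` in-place list now target `Unbroadcast`, scan calls `aesara.tensor.shape.unbroadcast` directly, and the `patternbroadcast`/`addbroadcast` calls in `aesara.sparse` become `specify_broadcastable`. A minimal sketch (not part of the diff) of the user-facing replacement, with illustrative variable names:

```python
import aesara.tensor as at
from aesara.tensor.shape import specify_broadcastable

x = at.row("x")       # broadcastable pattern (True, False)
gx = at.matrix("gx")  # broadcastable pattern (False, False)

# Previously (the removed line in the aesara/sparse/basic.py hunk below):
#   gx = at.patternbroadcast(gx, x.broadcastable)
gx = specify_broadcastable(
    gx, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
print(gx.broadcastable)  # (True, False)
```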

2 changes: 1 addition & 1 deletion aesara/__init__.py
@@ -147,7 +147,7 @@ def _as_symbolic(x, **kwargs) -> Variable:
def get_scalar_constant_value(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.

If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
this function digs through them.

If ``aesara.sparse`` is also there, we will look over CSM `Op`.
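A minimal sketch (not part of the diff) of the documented behaviour, assuming the wrapped value is a true constant:

```python
import aesara.tensor as at
from aesara.tensor.basic import get_scalar_constant_value

# A 0-d constant hidden behind a dim-shuffle; the helper digs through the
# wrapper and returns the underlying scalar value.
v = at.constant(5.0).dimshuffle("x", "x")
print(get_scalar_constant_value(v))  # 5.0
```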
4 changes: 2 additions & 2 deletions aesara/compile/function/pfunc.py
@@ -204,8 +204,8 @@ def clone_inputs(i):
err_sug = (
"If the difference is related to the broadcast pattern,"
" you can call the"
" tensor.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to remove broadcastable dimensions."
" tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to mask broadcastable dimensions."
)

raise TypeError(err_msg, err_sug)
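A minimal sketch (not part of the diff) of the call the error message now suggests; `unbroadcast` masks the broadcastable flag on the given axes so the variable's type can match a non-broadcastable counterpart:

```python
import aesara.tensor as at
from aesara.tensor.shape import unbroadcast

x = at.row("x")         # broadcastable pattern (True, False)
y = unbroadcast(x, 0)   # axis 0 is no longer marked broadcastable
print(y.broadcastable)  # (False, False)
```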
5 changes: 2 additions & 3 deletions aesara/ifelse.py
@@ -23,8 +23,7 @@
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.tensor import basic
from aesara.tensor.shape import Reshape, Shape, SpecifyShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast


__docformat__ = "restructedtext en"
@@ -451,7 +450,7 @@ def cond_make_inplace(fgraph, node):
Shape,
SpecifyShape,
Reshape,
basic.Rebroadcast,
Unbroadcast,
at.math.Dot,
at.math.MaxAndArgmax,
at.subtensor.Subtensor,
19 changes: 5 additions & 14 deletions aesara/link/jax/dispatch.py
@@ -29,7 +29,6 @@
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
@@ -50,7 +49,7 @@
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
@@ -347,20 +346,12 @@ def specifyshape(x, *shape):
return specifyshape


@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op, **kwargs):
op_axis = op.axis

def rebroadcast(x):
for axis, value in op_axis.items():
if value and x.shape[axis] != 1:
raise ValueError(
"Dimension %s in Rebroadcast's input was"
" supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
)
@jax_funcify.register(Unbroadcast)
def jax_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x

return rebroadcast
return unbroadcast


@jax_funcify.register(ViewOp)
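Unlike the old `Rebroadcast` thunk, which validated at run time that every axis marked broadcastable had length 1, the `Unbroadcast` thunk is a pure identity: unbroadcasting only relaxes Aesara's static broadcastable information. A minimal sketch (not part of the diff) of the dispatch in use, assuming `Unbroadcast` takes the axes to unbroadcast as constructor arguments:

```python
import jax.numpy as jnp

from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.shape import Unbroadcast

# Funcify an Unbroadcast op and check that the resulting callable is a pass-through.
fn = jax_funcify(Unbroadcast(0))
print(fn(jnp.ones((1, 3))).shape)  # (1, 3)
```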
19 changes: 5 additions & 14 deletions aesara/link/numba/dispatch/tensor_basic.py
@@ -14,10 +14,10 @@
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
from aesara.tensor.shape import Unbroadcast


@numba_funcify.register(AllocEmpty)
@@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn)


@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
# Make sure op_axis only has ints. This way we can avoid literal_unroll
# which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
op_axis = tuple((axis, int(value)) for axis, value in op.axis.items())

@numba_funcify.register(Unbroadcast)
def numba_funcify_Unbroadcast(op, **kwargs):
@numba_basic.numba_njit
def rebroadcast(x):
for axis, value in op_axis:
if value and x.shape[axis] != 1:
raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1")
)
def unbroadcast(x):
return x

return rebroadcast
return unbroadcast


@numba_funcify.register(TensorFromScalar)
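As in the JAX backend, the Numba thunk no longer needs the run-time shape check (or the `literal_unroll` workaround that came with it); it compiles down to a plain identity. A minimal sketch (not part of the diff) of what the generated function amounts to:

```python
import numpy as np
from numba import njit

@njit
def unbroadcast(x):
    # No check, no copy: Unbroadcast only changes static type information.
    return x

print(unbroadcast(np.ones((1, 3))).shape)  # (1, 3)
```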
8 changes: 4 additions & 4 deletions aesara/scan/basic.py
@@ -14,7 +14,7 @@
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum
from aesara.tensor.shape import shape_padleft
from aesara.tensor.shape import shape_padleft, unbroadcast
from aesara.tensor.type import TensorType, integer_dtypes
from aesara.updates import OrderedUpdates

@@ -751,7 +751,7 @@ def wrap_into_list(x):
# defined in scan utils
sit_sot_scan_inputs.append(
expand_empty(
at.unbroadcast(shape_padleft(actual_arg), 0),
unbroadcast(shape_padleft(actual_arg), 0),
actual_n_steps,
)
)
@@ -881,7 +881,7 @@ def wrap_into_list(x):
# this will represent only a slice and it will have one
# dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0)
outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)

if not return_list and len(outputs) == 1:
outputs = outputs[0]
@@ -1010,7 +1010,7 @@ def wrap_into_list(x):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
expand_empty(
at.unbroadcast(shape_padleft(input.variable), 0),
unbroadcast(shape_padleft(input.variable), 0),
actual_n_steps,
)
)
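A minimal sketch (not part of the diff) of the recurring pattern above: `shape_padleft` adds a leading broadcastable dimension, and `unbroadcast` then clears that flag so scan can allocate and write multiple time steps along it:

```python
import aesara.tensor as at
from aesara.tensor.shape import shape_padleft, unbroadcast

x = at.vector("x")                    # broadcastable pattern (False,)
y = unbroadcast(shape_padleft(x), 0)  # shape (1, n) with axis 0 unbroadcast
print(y.broadcastable)                # (False, False)
```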
2 changes: 1 addition & 1 deletion aesara/scan/op.py
@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor."
"{patternbroadcast,unbroadcast,addbroadcast}."
"{unbroadcast, specify_broadcastable}."
)
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(
10 changes: 6 additions & 4 deletions aesara/sparse/basic.py
@@ -45,7 +45,7 @@
tanh,
trunc,
)
from aesara.tensor.shape import shape
from aesara.tensor.shape import shape, specify_broadcastable
from aesara.tensor.type import TensorType
from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -1136,7 +1136,9 @@ def grad(self, inputs, gout):
(x,) = inputs
(gz,) = gout
gx = dense_from_sparse(gz)
gx = at.patternbroadcast(gx, x.broadcastable)
gx = specify_broadcastable(
gx, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
return (gx,)

def infer_shape(self, fgraph, node, shapes):
@@ -1900,9 +1902,9 @@ def grad(self, inputs, gout):
else:
ones = at.ones_like(x)
if self.axis == 0:
r = at.addbroadcast(gz.dimshuffle("x", 0), 0) * ones
r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones
elif self.axis == 1:
r = at.addbroadcast(gz.dimshuffle(0, "x"), 1) * ones
r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones
else:
raise ValueError("Illegal value for self.axis.")
r = SparseFromDense(o_format)(r)
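A minimal sketch (not part of the diff) of the second replacement above: `at.addbroadcast` is swapped for `specify_broadcastable`, which asserts that the named axes have length 1 without changing any run-time values:

```python
import aesara.tensor as at
from aesara.tensor.shape import specify_broadcastable

gz = at.vector("gz")
r = specify_broadcastable(gz.dimshuffle("x", 0), 0)  # assert axis 0 is length 1
print(r.broadcastable)  # (True, False)
```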
10 changes: 2 additions & 8 deletions aesara/sparse/opt.py
@@ -19,7 +19,7 @@
usmm,
)
from aesara.tensor import blas
from aesara.tensor.basic import as_tensor_variable, cast, patternbroadcast
from aesara.tensor.basic import as_tensor_variable, cast
from aesara.tensor.basic_opt import register_canonicalize, register_specialize
from aesara.tensor.math import mul, neg, sub
from aesara.tensor.shape import shape, specify_shape
@@ -42,13 +42,7 @@ def local_csm_properties_csm(fgraph, node):
if node.op == csm_properties:
(csm,) = node.inputs
if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR):
# csm.owner.inputs could be broadcastable. In that case, we have
# to adjust the broadcasting flag here.
ret_var = [
patternbroadcast(i, o.broadcastable)
for i, o in zip(csm.owner.inputs, node.outputs)
]
return ret_var
return csm.owner.inputs

return False
