Deprecate remaining uses of Rebroadcast in favor of MaskBroadcastable
Ricardo committed May 17, 2022
1 parent 5397f4c commit 2cb7222
Showing 19 changed files with 359 additions and 544 deletions.
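In short, call sites that previously applied `at.unbroadcast` now call `mask_broadcastable`. A minimal before/after sketch, using only names that appear in this diff (the variables are illustrative):

import aesara.tensor as at
from aesara.tensor.shape import mask_broadcastable, shape_padleft

x = at.vector("x")
y = shape_padleft(x)  # prepend a length-1 dimension, marked broadcastable

# Before this commit:
#   z = at.unbroadcast(y, 0)
# After this commit:
z = mask_broadcastable(y, 0)  # clear the broadcastable flag on axis 0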
2 changes: 1 addition & 1 deletion aesara/__init__.py
@@ -147,7 +147,7 @@ def _as_symbolic(x, **kwargs) -> Variable:
 def get_scalar_constant_value(v):
     """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
 
-    If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
+    If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
     this function digs through them.
 
     If ``aesara.sparse`` is also there, we will look over CSM `Op`.
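A quick sketch of the "digging" this docstring describes, assuming the behavior documented above (the `fill`/`cast` wrapping is illustrative and not verified against this exact revision):

import aesara.tensor as at
from aesara import get_scalar_constant_value

# `v` hides the constant 3 behind a fill and a cast; the helper digs through both.
v = at.cast(at.fill(at.zeros(()), 3), "int64")
assert get_scalar_constant_value(v) == 3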
4 changes: 2 additions & 2 deletions aesara/compile/function/pfunc.py
@@ -204,8 +204,8 @@ def clone_inputs(i):
             err_sug = (
                 "If the difference is related to the broadcast pattern,"
                 " you can call the"
-                " tensor.unbroadcast(var, axis_to_unbroadcast[, ...])"
-                " function to remove broadcastable dimensions."
+                " tensor.shape.mask_broadcastable(var, axis_to_unbroadcast[, ...])"
+                " function to mask broadcastable dimensions."
             )
 
             raise TypeError(err_msg, err_sug)
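A hedged sketch of the situation this suggestion targets: an update whose broadcast pattern no longer matches the shared variable it replaces (the shapes and the exact type behavior are illustrative):

import numpy as np
import aesara
import aesara.tensor as at
from aesara.tensor.shape import mask_broadcastable

s = aesara.shared(np.zeros((2, 3)))  # type has no broadcastable dimensions
new_val = at.alloc(0.0, 1, 3)        # axis 0 is statically broadcastable

# aesara.function([], [], updates=[(s, new_val)])  # would raise the TypeError above
f = aesara.function([], [], updates=[(s, mask_broadcastable(new_val, 0))])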
5 changes: 2 additions & 3 deletions aesara/ifelse.py
@@ -23,8 +23,7 @@
 from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
 from aesara.graph.op import _NoPythonOp
 from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
-from aesara.tensor import basic
-from aesara.tensor.shape import Reshape, Shape, SpecifyShape
+from aesara.tensor.shape import MaskBroadcastable, Reshape, Shape, SpecifyShape
 
 
 __docformat__ = "restructedtext en"
@@ -451,7 +450,7 @@ def cond_make_inplace(fgraph, node):
            Shape,
            SpecifyShape,
            Reshape,
-           basic.Rebroadcast,
+           MaskBroadcastable,
            at.math.Dot,
            at.math.MaxAndArgmax,
            at.subtensor.Subtensor,
19 changes: 5 additions & 14 deletions aesara/link/jax/dispatch.py
@@ -29,7 +29,6 @@
     Eye,
     Join,
     MakeVector,
-    Rebroadcast,
     ScalarFromTensor,
     TensorFromScalar,
 )
@@ -50,7 +49,7 @@
 from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
 from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
 from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
+from aesara.tensor.shape import MaskBroadcastable, Reshape, Shape, Shape_i, SpecifyShape
 from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -347,20 +346,12 @@ def specifyshape(x, *shape):
     return specifyshape
 
 
-@jax_funcify.register(Rebroadcast)
-def jax_funcify_Rebroadcast(op, **kwargs):
-    op_axis = op.axis
-
-    def rebroadcast(x):
-        for axis, value in op_axis.items():
-            if value and x.shape[axis] != 1:
-                raise ValueError(
-                    "Dimension %s in Rebroadcast's input was"
-                    " supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
-                )
+@jax_funcify.register(MaskBroadcastable)
+def jax_funcify_MaskBroadcastable(op, **kwargs):
+    def mask_broadcastable(x):
         return x
 
-    return rebroadcast
+    return mask_broadcastable
 
 
 @jax_funcify.register(ViewOp)
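Why the new implementation is an identity function: `Rebroadcast` could set a dimension's broadcast flag to `True`, which required verifying at run time that the dimension really had length 1, hence the `ValueError` removed above. `MaskBroadcastable` only clears broadcast flags, a purely static change to the variable's type, so the backend has nothing to check and passes the value through unchanged. The same reasoning applies to the Numba implementation below.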
18 changes: 5 additions & 13 deletions aesara/link/numba/dispatch/tensor_basic.py
@@ -1,6 +1,5 @@
 from textwrap import indent
 
-import numba
 import numpy as np
 
 from aesara.link.numba.dispatch import basic as numba_basic
@@ -15,10 +14,10 @@
     Eye,
     Join,
     MakeVector,
-    Rebroadcast,
     ScalarFromTensor,
     TensorFromScalar,
 )
+from aesara.tensor.shape import MaskBroadcastable
 
 
 @numba_funcify.register(AllocEmpty)
@@ -196,20 +195,13 @@ def makevector({", ".join(input_names)}):
     return numba_basic.numba_njit(makevector_fn)
 
 
-@numba_funcify.register(Rebroadcast)
-def numba_funcify_Rebroadcast(op, **kwargs):
-    op_axis = tuple(op.axis.items())
-
+@numba_funcify.register(MaskBroadcastable)
+def numba_funcify_MaskBroadcastable(op, **kwargs):
     @numba_basic.numba_njit
-    def rebroadcast(x):
-        for axis, value in numba.literal_unroll(op_axis):
-            if value and x.shape[axis] != 1:
-                raise ValueError(
-                    ("Dimension in Rebroadcast's input was supposed to be 1")
-                )
+    def mask_broadcastable(x):
         return x
 
-    return rebroadcast
+    return mask_broadcastable
 
 
 @numba_funcify.register(TensorFromScalar)
8 changes: 4 additions & 4 deletions aesara/scan/basic.py
@@ -14,7 +14,7 @@
 from aesara.tensor.basic import get_scalar_constant_value
 from aesara.tensor.exceptions import NotScalarConstantError
 from aesara.tensor.math import minimum
-from aesara.tensor.shape import shape_padleft
+from aesara.tensor.shape import mask_broadcastable, shape_padleft
 from aesara.tensor.type import TensorType, integer_dtypes
 from aesara.updates import OrderedUpdates
@@ -751,7 +751,7 @@ def wrap_into_list(x):
             # defined in scan utils
             sit_sot_scan_inputs.append(
                 expand_empty(
-                    at.unbroadcast(shape_padleft(actual_arg), 0),
+                    mask_broadcastable(shape_padleft(actual_arg), 0),
                     actual_n_steps,
                 )
             )
@@ -881,7 +881,7 @@ def wrap_into_list(x):
         # this will represent only a slice and it will have one
         # dimension less.
         if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
-            outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0)
+            outputs[pos] = mask_broadcastable(shape_padleft(inner_out), 0)
 
     if not return_list and len(outputs) == 1:
         outputs = outputs[0]
@@ -1010,7 +1010,7 @@ def wrap_into_list(x):
         sit_sot_inner_inputs.append(new_var)
         sit_sot_scan_inputs.append(
             expand_empty(
-                at.unbroadcast(shape_padleft(input.variable), 0),
+                mask_broadcastable(shape_padleft(input.variable), 0),
                 actual_n_steps,
             )
         )
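All three call-site changes in this file follow one pattern: `shape_padleft` prepends a length-1 dimension, which aesara marks broadcastable, but scan will grow that axis to the number of steps, so the flag must be masked off first. A sketch with illustrative names:

import aesara.tensor as at
from aesara.tensor.shape import mask_broadcastable, shape_padleft

x = at.vector("x")            # broadcastable pattern (False,)
y = shape_padleft(x)          # pattern (True, False): axis 0 has length 1
z = mask_broadcastable(y, 0)  # pattern (False, False): axis 0 is free to grow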
2 changes: 1 addition & 1 deletion aesara/scan/op.py
@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
         "dimension is fixed to 1 in the input, while it is still "
         "variable in the output, or vice-verca. You have to make "
         "them consistent, e.g. using aesara.tensor."
-        "{unbroadcast, specify_broadcastable}."
+        "{mask_broadcastable, specify_broadcastable}."
     )
     size = min(len(v1.broadcastable), len(v2.broadcastable))
     for n, (b1, b2) in enumerate(
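The two helpers named in this message move in opposite directions; a sketch, assuming `specify_broadcastable` is exposed on `aesara.tensor` as the message implies:

import aesara.tensor as at
from aesara.tensor.shape import mask_broadcastable

v = at.row("v")                     # broadcastable pattern (True, False)
a = mask_broadcastable(v, 0)        # drop the flag: pattern (False, False)

m = at.matrix("m")                  # pattern (False, False)
b = at.specify_broadcastable(m, 0)  # assert the flag: pattern (True, False)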
