Skip to content

Commit

Permalink
Replace uses of Rebroadcast by SpecifyShape
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Apr 19, 2022
1 parent 00e0d80 commit f8211eb
Show file tree
Hide file tree
Showing 18 changed files with 201 additions and 528 deletions.
2 changes: 1 addition & 1 deletion aesara/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _as_symbolic(x, **kwargs) -> Variable:
def get_scalar_constant_value(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
If `v` is the output of dim-shuffles, fills, allocs, specify_shapes, cast
this function digs through them.
If ``aesara.sparse`` is also there, we will look over CSM `Op`.
Expand Down
4 changes: 2 additions & 2 deletions aesara/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.tensor import basic
from aesara.tensor.shape import Reshape, Shape, SpecifyShape


Expand Down Expand Up @@ -451,7 +450,8 @@ def cond_make_inplace(fgraph, node):
Shape,
SpecifyShape,
Reshape,
basic.Rebroadcast,
# TODO: Check if SpecifyShape can be lifted through where old rebroadcast used to
# basic.Rebroadcast,
at.math.Dot,
at.math.MaxAndArgmax,
at.subtensor.Subtensor,
Expand Down
17 changes: 0 additions & 17 deletions aesara/link/jax/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
Expand Down Expand Up @@ -348,22 +347,6 @@ def specifyshape(x, *shape):
return specifyshape


@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op, **kwargs):
    """Build a JAX-compatible callable for a `Rebroadcast` `Op`.

    The returned function validates at run time that every axis the `Op`
    flags as broadcastable actually has length 1, then returns its input
    unchanged (rebroadcasting is a metadata-only operation).
    """
    axes_to_check = op.axis

    def rebroadcast(x):
        # Only axes flagged ``True`` must have length 1; all others are
        # left alone.
        for axis, must_be_broadcastable in axes_to_check.items():
            if not must_be_broadcastable:
                continue
            if x.shape[axis] == 1:
                continue
            raise ValueError(
                "Dimension %s in Rebroadcast's input was"
                " supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
            )
        return x

    return rebroadcast


@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op, **kwargs):
def viewop(x):
Expand Down
18 changes: 0 additions & 18 deletions aesara/link/numba/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from textwrap import indent

import numba
import numpy as np

from aesara.link.numba.dispatch import basic as numba_basic
Expand All @@ -15,7 +14,6 @@
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
Expand Down Expand Up @@ -196,22 +194,6 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn)


@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
    """Build a Numba-compiled callable for a `Rebroadcast` `Op`.

    The generated function checks that every axis flagged ``True`` in
    ``op.axis`` has run-time length 1 and returns the input unchanged.
    """
    # Freeze the axis dict into a tuple of pairs so the loop below can be
    # unrolled with ``numba.literal_unroll`` (dict iteration is not
    # supported in nopython mode).
    op_axis = tuple(op.axis.items())

    @numba_basic.numba_njit
    def rebroadcast(x):
        for axis, value in numba.literal_unroll(op_axis):
            if value and x.shape[axis] != 1:
                # NOTE(review): the offending axis and actual length are not
                # interpolated into the message — presumably because dynamic
                # string formatting is limited inside nopython mode; confirm
                # before changing.
                raise ValueError(
                    ("Dimension in Rebroadcast's input was supposed to be 1")
                )
        return x

    return rebroadcast


@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_basic.numba_njit(inline="always")
Expand Down

0 comments on commit f8211eb

Please sign in to comment.