Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove inplace broadcast_add #5551

Merged
merged 2 commits into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
bind_python: True

- name: "broadcast_add"
signature: "Tensor BroadcastAdd(Tensor x, Tensor y, *, Bool inplace=False)"
signature: "Tensor BroadcastAdd(Tensor x, Tensor y)"
bind_python: True

- name: "sub_scalar_by_tensor"
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/functional/impl/binary_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class PowFunctor : public BinaryFunctor {
}
};

class BroadcastAddFunctor : public InplaceableBinaryFunctor {
class BroadcastAddFunctor : public BinaryFunctor {
public:
BroadcastAddFunctor() {
op_ = CHECK_JUST(one::OpBuilder("broadcast_add").Input("x").Input("y").Output("z").Build());
Expand Down
28 changes: 25 additions & 3 deletions oneflow/python/nn/modules/broadcast_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,43 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Sequence

import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export, experimental_api


def _calc_broadcast_axes(x, like_tensor):
num_prepend = len(like_tensor.shape) - len(x.shape)
prepend_shape = [1] * num_prepend + list(x.shape)
broadcast_axes = [x for x in range(num_prepend)]
for i in range(num_prepend, len(prepend_shape)):
if prepend_shape[i] != like_tensor.shape[i]:
if prepend_shape[i] != 1:
raise RuntimeError(
f"output with shape {x.shape} doesn't match the broadcast shape {like_tensor.shape}"
)
else:
broadcast_axes.append(i)
return tuple(broadcast_axes)


class BroadCastLike(Module):
def __init__(self, broadcast_axes: None) -> None:
def __init__(self, broadcast_axes: Optional[Sequence] = None) -> None:
super().__init__()
self.broadcast_axes = broadcast_axes

def forward(self, x, like_tensor):
return flow.F.broadcast_like(x, like_tensor, broadcast_axes=self.broadcast_axes)
if self.broadcast_axes is None:
broadcast_axes = _calc_broadcast_axes(x, like_tensor)
else:
broadcast_axes = self.broadcast_axes
return flow.F.broadcast_like(x, like_tensor, broadcast_axes=broadcast_axes)


@oneflow_export("broadcast_like")
@experimental_api
def broadcast_like_op(x, like_tensor, broadcast_axes: None):
def broadcast_like_op(x, like_tensor, broadcast_axes: Optional[Sequence] = None):
return BroadCastLike(broadcast_axes=broadcast_axes)(x, like_tensor)
10 changes: 4 additions & 6 deletions oneflow/python/nn/modules/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,14 +391,11 @@ def forward(self, x, y):


class BroadcastAdd(Module):
def __init__(self, inplace: bool = False) -> None:
def __init__(self) -> None:
super().__init__()
self.inplace = inplace

def forward(self, x, y):
if self.inplace:
_check_inplace_valid(x)
return flow.F.broadcast_add(x, y, self.inplace)
return flow.F.broadcast_add(x, y)


@oneflow_export("add")
Expand Down Expand Up @@ -474,7 +471,8 @@ def _add_inplace(x, y):
elif y.shape == (1,):
return ScalarAddByTensor(inplace=True)(x, y)
else:
return BroadcastAdd(inplace=True)(x, y)
y = flow.experimental.broadcast_like(y, x)
return ElementwiseAdd(inplace=True)(x, y)


class Asin(Module):
Expand Down