Skip to content

Commit

Permalink
refine scatter signature
Browse files Browse the repository at this point in the history
  • Loading branch information
doombeaker committed Aug 4, 2021
1 parent 7eb881a commit 7e429c0
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 23 deletions.
44 changes: 22 additions & 22 deletions oneflow/api/python/functional/python_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,38 +148,38 @@ py::object PySub(py::args py_args, py::kwargs py_kwargs) {
}

py::object PyScatter(py::args py_args, py::kwargs py_kwargs) {
// "Tensor DimScatter(Tensor input, Tensor index, Tensor src, *, Int32 dim)"
// "Tensor DimScatter(Tensor input, Tensor index, *, Float src, Int32 dim)"
// Scatter(Tensor input, Int32 dim, Tensor index, Tensor src)
// Scatter(Tensor input, Int32 dim, Tensor index, float src)

PyObject* args = py_args.ptr();
PyObject* kwargs = py_kwargs.ptr();
size_t nargs = PyTuple_Size(args);
CHECK_EQ_OR_THROW(nargs, 4) << "4 positional inputs are required.";

const auto& result = [&]() -> Maybe<Tensor> { // NOLINT
Optional<Scalar> dim;
if (auto* dim_obj = PyDict_GetItemString(kwargs, "dim")) {
dim = *JUST(PyUnpackScalar(dim_obj));
}
PyObject* input = PyTuple_GetItem(args, 0);
PyObject* index = PyTuple_GetItem(args, 1);
PyObject* dim = PyTuple_GetItem(args, 1);
PyObject* index = PyTuple_GetItem(args, 2);
PyObject* src = PyTuple_GetItem(args, 3);

const auto& in = JUST(PyUnpackTensor(input));

Optional<Scalar> dim_scalar;
dim_scalar = *JUST(PyUnpackScalar(dim));
Scalar& dim_value = *JUST(dim_scalar.value());
int32_t d = JUST(dim_value.As<int32_t>());

const auto& idx = JUST(PyUnpackTensor(index));

if (nargs == 3) {
PyObject* src = PyTuple_GetItem(args, 2);
const auto& src_tensor = JUST(PyUnpackTensor(src));
bool is_src_tensor = PyTensorCheck(src);

return functional::DimScatter(in, idx, src_tensor, dim);
} else if (nargs == 2) {
Optional<Scalar> src;
if (auto* src_obj = PyDict_GetItemString(kwargs, "src")) {
src = *JUST(PyUnpackScalar(src_obj));
}
Scalar& src_scalar = *JUST(src.value());
return functional::DimScatterUpdateScalar(in, idx, JUST(src_scalar.As<float>()), dim);
if (is_src_tensor) {
const auto& src_tensor = JUST(PyUnpackTensor(src));
return functional::DimScatter(in, idx, src_tensor, d);
} else {
UNIMPLEMENTED_THEN_RETURN() << "none of:\n"
"(Tensor input, Tensor index, Tensor src, *, Int32 dim)"
"(Tensor input, Tensor index, *, Float src, Int32 dim)";
Optional<Scalar> src_scalar;

This comment has been minimized.

Copy link
@hjchen2

hjchen2 Aug 4, 2021

Contributor

这里的src_scalar不是optional的,就不需要用Optional < Scalar >了,直接

Scalar src_scalar = *JUST(PyUnpackScalar(src));

就好了

src_scalar = *JUST(PyUnpackScalar(src));
Scalar& src_value = *JUST(src_scalar.value());
return functional::DimScatterUpdateScalar(in, idx, JUST(src_value.As<float>()), d);
}
}();
return py::cast(result.GetPtrOrThrow());
Expand Down
49 changes: 48 additions & 1 deletion python/oneflow/nn/modules/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,54 @@
from oneflow.nn.module import Module


__all__ = ["scatter_add"]
__all__ = ["scatter", "scatter_add"]


def scatter(input, dim, index, src):
r"""This operator writes the elements specified by `index` along with the axis
`dim` from the `src` into the `input`.
Take a 3-D blob as example, the output is specified by:
.. code-block:: python
input[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
input[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
input[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
input, index and src (if it is a Tensor) should all have the same number of dimensions.
It is also required that index.shape(d) <= src.shape(d) for all dimensions d,
and that index.shape(d) <= self.shape(d) for all dimensions d != dim.
Note that index and src do not broadcast.
Args:
input (Tensor): The input blob.
dim (int): The axis along which to index
index (Tensor): The index blob of elements to scatter.
src (Tensor or float): The source blob whose elements will be scatterd and updated to output.
Returns:
Tensor: The scatterd Tensor.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> input = flow.ones((3,5))*2
>>> index = flow.tensor(np.array([[0,1,2],[0,1,4]], ), dtype=flow.int32)
>>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]]))
>>> out = flow.scatter(input, 1, index, src)
>>> out
tensor([[ 0., 10., 20., 2., 2.],
[50., 60., 2., 2., 70.],
[ 2., 2., 2., 2., 2.]], dtype=oneflow.float32)
"""

return flow.F.scatter(input, dim, index, src)


def scatter_add(input, dim, index, src):
Expand Down

0 comments on commit 7e429c0

Please sign in to comment.