From 1868c19f26d5536bb2cfea7476161cdf9deccd2c Mon Sep 17 00:00:00 2001 From: ZZK <42901638+MARD1NO@users.noreply.github.com> Date: Mon, 13 Sep 2021 10:02:59 +0800 Subject: [PATCH] Dev functional batch_gather (#6233) * add broadcast like docs * unsorted batch segment sum functional * add unittest * add docs * add batch gather docs rst * fix doc code Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- docs/source/oneflow.rst | 2 + .../autograd/gradient_funcs/batch_gather.cpp | 18 +--- oneflow/core/framework/op_expr_helper.cpp | 12 --- oneflow/core/framework/op_expr_helper.h | 3 - oneflow/core/functional/functional_api.yaml | 15 ++++ .../core/functional/impl/array_functor.cpp | 38 +++++++++ python/oneflow/__init__.py | 2 +- python/oneflow/framework/docstr/array_ops.py | 46 ++++++++++ python/oneflow/nn/modules/broadcast_like.py | 23 +++++ .../oneflow/test/modules/test_batch_gather.py | 85 +++++++++++++++++++ 10 files changed, 214 insertions(+), 30 deletions(-) create mode 100644 python/oneflow/test/modules/test_batch_gather.py diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index 247925e5e4e..8fb219f876a 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -27,6 +27,8 @@ oneflow atan2, atanh, bernoulli, + broadcast_like, + batch_gather, cat, cast, ceil, diff --git a/oneflow/core/autograd/gradient_funcs/batch_gather.cpp b/oneflow/core/autograd/gradient_funcs/batch_gather.cpp index b21fc693481..cebc13738a7 100644 --- a/oneflow/core/autograd/gradient_funcs/batch_gather.cpp +++ b/oneflow/core/autograd/gradient_funcs/batch_gather.cpp @@ -13,11 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" -#include "oneflow/core/framework/op_builder.h" -#include "oneflow/core/framework/op_expr.h" -#include "oneflow/core/framework/op_expr_helper.h" -#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { @@ -34,17 +32,11 @@ class BatchGather : public OpExprGradFunction { const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; - - private: - std::shared_ptr bw_unsorted_batch_segment_sum_op_; }; Maybe BatchGather::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); - const std::string& op_name = fw_op_expr->op_name(); - bw_unsorted_batch_segment_sum_op_ = - JUST(op_expr_helper::UnsortedBatchSegmentSumOp(/*num_segments=*/1, GradientOpName(op_name))); return Maybe::Ok(); } @@ -64,10 +56,8 @@ Maybe BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorT in_grads->resize(2); if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& indices = ctx->SavedTensors().at(0); - MutableAttrMap attrs; - JUST(attrs.SetAttr("num_segments", ctx->num_segments)); - in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*bw_unsorted_batch_segment_sum_op_, - {out_grads.at(0), indices}, attrs)); + in_grads->at(0) = + JUST(functional::UnsortedBatchSegmentSum(out_grads.at(0), indices, ctx->num_segments)); return Maybe::Ok(); } diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp index cdb0dcfc7bf..5398cf9858a 100644 --- a/oneflow/core/framework/op_expr_helper.cpp +++ b/oneflow/core/framework/op_expr_helper.cpp @@ -380,18 +380,6 @@ Maybe ConcatOp(const int& n, const int64_t& axis, const int64_t .Build(); } -Maybe UnsortedBatchSegmentSumOp(const int& 
num_segments) { - return UnsortedBatchSegmentSumOp(num_segments, UniqueOpName("unsorted_batch_segment_sum")); -} -Maybe UnsortedBatchSegmentSumOp(const int& num_segments, const std::string& name) { - return one::OpBuilder("unsorted_batch_segment_sum", name) - .Input("data") - .Input("segment_ids") - .Output("out") - .Attr("num_segments", num_segments) - .Build(); -} - Maybe ScalarAddByTensorOp() { return ScalarAddByTensorOp(UniqueOpName("scalar_add_by_tensor")); } diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h index 47149e64698..3cc708a294a 100644 --- a/oneflow/core/framework/op_expr_helper.h +++ b/oneflow/core/framework/op_expr_helper.h @@ -124,9 +124,6 @@ Maybe ConcatOp(const int& n, const int64_t& axis, const int64_t Maybe ConcatOp(const int& n, const int64_t& axis, const int64_t& max_dim_size, const std::string& name); -Maybe UnsortedBatchSegmentSumOp(const int& num_segments); -Maybe UnsortedBatchSegmentSumOp(const int& num_segments, const std::string& name); - Maybe ScalarAddByTensorOp(); Maybe ScalarAddByTensorOp(const std::string& name); diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index c5ed09d4636..8ef97323e6b 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1221,3 +1221,13 @@ - name: "recv" signature: "Tensor (Int64 src, Shape shape=None, DataType dtype=None, Device device=None, *, Tensor out=None) => Recv" bind_python: True + +- name: "batch_gather" + signature: + "Tensor (Tensor in, Tensor indices) => BatchGather" + bind_python: True + +- name: "unsorted_batch_segment_sum" + signature: + "Tensor (Tensor data, Tensor segment_ids, Int64 num_segments) => UnsortedBatchSegmentSum" + bind_python: False diff --git a/oneflow/core/functional/impl/array_functor.cpp 
b/oneflow/core/functional/impl/array_functor.cpp index 7219380dd1f..3d3f8644bd4 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -1629,6 +1629,42 @@ class SplitWithSizeFunctor { } }; +class BatchGatherFunctor { + public: + BatchGatherFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("batch_gather").Input("in").Input("indices").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& in, + const std::shared_ptr& indices) const { + return OpInterpUtil::Dispatch(*op_, {in, indices}); + } + + protected: + std::shared_ptr op_; +}; + +class UnsortedBatchSegmentSumFunctor { + public: + UnsortedBatchSegmentSumFunctor() { + op_ = CHECK_JUST(one::OpBuilder("unsorted_batch_segment_sum") + .Input("data") + .Input("segment_ids") + .Output("out") + .Build()); + } + Maybe operator()(const std::shared_ptr& data, + const std::shared_ptr& segment_ids, + const int64_t& num_segments) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("num_segments", num_segments)); + return OpInterpUtil::Dispatch(*op_, {data, segment_ids}, attrs); + } + + protected: + std::shared_ptr op_; +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -1708,6 +1744,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Split"); m.add_functor("SplitLike"); m.add_functor("SplitWithSize"); + m.add_functor("BatchGather"); + m.add_functor("UnsortedBatchSegmentSum"); }; } // namespace functional diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 03a0de0b3e3..d208a327eb0 100644 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -206,7 +206,7 @@ def atexit_hook(hook): del oneflow import oneflow._C -from oneflow._C import tensor +from oneflow._C import tensor, batch_gather from oneflow.autograd import grad_enable, no_grad, inference_mode, is_grad_enabled import oneflow.nn.image diff --git a/python/oneflow/framework/docstr/array_ops.py b/python/oneflow/framework/docstr/array_ops.py index 
4ca54ab4d7c..f39b5446fe0 100644 --- a/python/oneflow/framework/docstr/array_ops.py +++ b/python/oneflow/framework/docstr/array_ops.py @@ -104,3 +104,49 @@ """, ) + +add_docstr( + oneflow.batch_gather, + """Gather the element in batch dims. + + Args: + in (Tensor): the input tensor. + indices (Tensor): the indices tensor, its dtype must be int32 or int64. + + For example: + + Example 1: + + .. code-block:: python + + >>> import oneflow as flow + >>> import numpy as np + + >>> x = flow.Tensor(np.array([[1, 2, 3], + ... [4, 5, 6]])) + >>> indices = flow.tensor(np.array([1, 0]).astype(np.int64)) + >>> out = flow.batch_gather(x, indices) + + tensor([[4., 5., 6.], + [1., 2., 3.]], dtype=oneflow.float32) + + Example 2: + + .. code-block:: python + + >>> import oneflow as flow + >>> import numpy as np + + >>> x = flow.Tensor(np.array([[[1, 2, 3], [4, 5, 6]], + ... [[1, 2, 3], [4, 5, 6]]])) + >>> indices = flow.tensor(np.array([[1, 0], + ... [0, 1]]).astype(np.int64)) + >>> out = flow.batch_gather(x, indices) + + tensor([[[4., 5., 6.], + [1., 2., 3.]], + [[1., 2., 3.], + [4., 5., 6.]]], dtype=oneflow.float32) + + """, +) diff --git a/python/oneflow/nn/modules/broadcast_like.py b/python/oneflow/nn/modules/broadcast_like.py index 5e925852908..c3a3a486613 100644 --- a/python/oneflow/nn/modules/broadcast_like.py +++ b/python/oneflow/nn/modules/broadcast_like.py @@ -48,4 +48,27 @@ def forward(self, x, like_tensor): def broadcast_like_op(x, like_tensor, broadcast_axes: Optional[Sequence] = None): + """This operator broadcasts tensor `x` to `like_tensor` according to the broadcast_axes. + + Args: + x (Tensor): The input Tensor. + like_tensor (Tensor): The like Tensor. + broadcast_axes (Optional[Sequence], optional): The axes you want to broadcast. Defaults to None. + + Returns: + [Tensor]: Broadcasted input Tensor. + + For example: + + .. 
code:: python + + >>> import oneflow as flow + + >>> x = flow.randn(3, 1, 1) + >>> like_tensor = flow.randn(3, 4, 5) + >>> broadcast_tensor = flow.broadcast_like(x, like_tensor, broadcast_axes=[1, 2]) + >>> broadcast_tensor.shape + oneflow.Size([3, 4, 5]) + + """ return BroadCastLike(broadcast_axes=broadcast_axes)(x, like_tensor) diff --git a/python/oneflow/test/modules/test_batch_gather.py b/python/oneflow/test/modules/test_batch_gather.py new file mode 100644 index 00000000000..db1e0688018 --- /dev/null +++ b/python/oneflow/test/modules/test_batch_gather.py @@ -0,0 +1,80 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +from collections import OrderedDict +import os + +import numpy as np +from test_util import GenArgList + +import oneflow as flow +import oneflow.unittest + + +def _test_batch_gather(test_case, shape, device): + # for example: shape = (3, 2, 2) + x = np.random.randn(*shape) + x_tensor = flow.Tensor(x).to(device) + x_tensor.requires_grad = True + batchsize = x.shape[0] + init_index = np.array( + [np.random.randint(batchsize) for i in range(batchsize)] + ).astype(np.int64) + + batch_gather_index = flow.tensor(init_index).to(device) + batch_gather_out = flow.batch_gather(x_tensor, batch_gather_index) + + x_tensor_gather = flow.Tensor(x).to(device) + x_tensor_gather.requires_grad = True + reshaped_shape = [batchsize] # reshaped_shape = [3] + for i in range(len(x.shape) - 1): + reshaped_shape.append(1) # reshaped_shape = [3] -> [3, 1, 1] + + gather_index = np.reshape(init_index, reshaped_shape) + gather_index = np.broadcast_to(gather_index, shape).astype( + np.int64 + ) # [3, 1, 1] -> [3, 2, 2] + gather_index = flow.tensor(gather_index).to(device) + gather_out = flow.gather(x_tensor_gather, gather_index, dim=0) + total_out = batch_gather_out.sum() + gather_out.sum() + total_out.backward() + + test_case.assertTrue( + np.allclose(batch_gather_out.numpy(), gather_out.numpy(), atol=1e-4, rtol=1e-4) + ) + + test_case.assertTrue( + np.allclose( + x_tensor.grad.numpy(), x_tensor_gather.grad.numpy(), atol=1e-4, rtol=1e-4, + ) + ) + + +@flow.unittest.skip_unless_1n1d() +class TestBatchGather(flow.unittest.TestCase): + def test_batch_gather(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [_test_batch_gather] + arg_dict["shape"] = [(3, 2, 2), (3, 2, 4, 2), (3, 3, 4, 2, 2), (4, 2)] + arg_dict["device"] = ["cpu", "cuda"] + + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + +if __name__ == "__main__": + 
unittest.main()