Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debug dim scatter #5371

Merged
merged 109 commits into from
Jul 31, 2021
Merged
Show file tree
Hide file tree
Changes from 98 commits
Commits
Show all changes
109 commits
Select commit Hold shift + click to select a range
e946449
startup of dev scatter ops
doombeaker Nov 19, 2020
bf505f8
use dim scatter base class
doombeaker Nov 20, 2020
a6d28b9
refine(using binop to abstract scatter update and add
doombeaker Nov 20, 2020
bb05c9b
refine (use macros to implement kerenl class and functors)
doombeaker Nov 20, 2020
cc4ad5a
refine(description for register scatter ops/kernels)
doombeaker Nov 20, 2020
645a19e
refine
doombeaker Nov 20, 2020
2b0b146
add inplace ops
doombeaker Nov 21, 2020
9081a0a
python wraper scatter_add inplace
doombeaker Nov 23, 2020
24fe455
dev inplace ops
doombeaker Nov 24, 2020
60b1fe9
refine dim_gather (using macros register mechanism)
doombeaker Nov 24, 2020
2bdca8b
Merge branch 'master' into dev_dim_scatter
doombeaker Nov 26, 2020
99be5ad
add grad of scatter_add_like
doombeaker Nov 26, 2020
2e883fa
Merge branch 'master' into dev_dim_scatter
doombeaker Dec 15, 2020
4f89f18
refine (add src, like versions for scatter)
doombeaker Dec 15, 2020
35c7457
refine src/like tensor
doombeaker Dec 15, 2020
bd2520c
gather refine(no need outplace/inplace versions)
doombeaker Dec 16, 2020
c5515a0
reformat
doombeaker Dec 16, 2020
e184293
refine
doombeaker Dec 16, 2020
c529461
test case of dim scatter
doombeaker Dec 28, 2020
02fa7ee
test case for dim_scatter_add_like
doombeaker Dec 28, 2020
29b6774
1n2d test case for dim_scatter_add_like
doombeaker Dec 28, 2020
57dca96
refine scatter sbp
doombeaker Dec 28, 2020
005c97e
fail to sccater_add_like on 1n2d
doombeaker Dec 28, 2020
bbb99bc
refing sbp
doombeaker Dec 29, 2020
982bfb8
refine test case, unify add and update like ops
doombeaker Dec 29, 2020
f712efb
test case for scatter_add/update like ops finished
doombeaker Dec 29, 2020
0c64679
test cases for scatter ops
doombeaker Dec 29, 2020
d4da91b
refine, merge test class
doombeaker Dec 29, 2020
56e309b
startup of api docs
doombeaker Dec 29, 2020
6401f4d
add scatter api docs and assertion in python
doombeaker Dec 30, 2020
afe1872
Merge branch 'master' into debug_dim_scatter
MARD1NO Jul 2, 2021
547adde
fix make error but still segment fault
MARD1NO Jul 2, 2021
e99396f
annotate sbp infer
MARD1NO Jul 2, 2021
332501e
rewrite scatter kernel logic
MARD1NO Jul 5, 2021
5915d12
remove inplace proposal and fix macro name
MARD1NO Jul 5, 2021
327944f
remove outdated atomic add
MARD1NO Jul 6, 2021
ae61462
move sbp infer
MARD1NO Jul 6, 2021
cde37c0
add const and throw error
MARD1NO Jul 6, 2021
98e0f70
add check
MARD1NO Jul 7, 2021
472040f
set grad op
MARD1NO Jul 7, 2021
d781ff1
add scatter scalar
MARD1NO Jul 8, 2021
0fea30b
add scatter scalar gpu kernel
MARD1NO Jul 8, 2021
3610e45
add torch style backprop
MARD1NO Jul 8, 2021
22e16e7
add torch style backprop check
MARD1NO Jul 8, 2021
e555948
align with master
MARD1NO Jul 8, 2021
f75a5ba
remove redundant sbp check
MARD1NO Jul 8, 2021
4a10851
add test
MARD1NO Jul 12, 2021
b95fff6
add float16n register
MARD1NO Jul 12, 2021
7427322
fix sbp
MARD1NO Jul 12, 2021
88f7a87
fix sbp
MARD1NO Jul 12, 2021
679c4cd
add api doc
MARD1NO Jul 12, 2021
1032c5f
make format
MARD1NO Jul 12, 2021
f0da7ab
add new line
MARD1NO Jul 12, 2021
af131aa
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 16, 2021
e1322b7
refine
doombeaker Jul 16, 2021
24f82e8
revert dim gather
doombeaker Jul 16, 2021
01acd08
extract dim_scatter_add
doombeaker Jul 17, 2021
cf40ffc
extracat scatter update ops
doombeaker Jul 17, 2021
d4cda6d
add add/update functor
doombeaker Jul 17, 2021
e4a56ad
rewrting by functors
doombeaker Jul 18, 2021
1e2d724
refine
doombeaker Jul 18, 2021
e569e40
remove dim_gather_scatter_uitl.h
doombeaker Jul 18, 2021
2bcf541
add blank line
doombeaker Jul 18, 2021
919eb6b
refine macros for registering kerenls
doombeaker Jul 18, 2021
3a39636
refine dim_scatter_scalar files name
doombeaker Jul 18, 2021
77d2c9e
refine
doombeaker Jul 18, 2021
4ee8b54
refine register ops
doombeaker Jul 18, 2021
2bcd31e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
doombeaker Jul 18, 2021
bb2d66f
refine
doombeaker Jul 18, 2021
d873ef0
add F.dim_scatter_scalar
doombeaker Jul 18, 2021
33c5361
add scatter op
doombeaker Jul 18, 2021
b42c795
refine docstr
doombeaker Jul 18, 2021
5667dd4
add scatter reduce arg
doombeaker Jul 18, 2021
d7b9d44
finally(!): a draft for scatter constitent with pytroch
doombeaker Jul 18, 2021
a484afa
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
doombeaker Jul 19, 2021
c95ef38
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
doombeaker Jul 19, 2021
11f8735
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
doombeaker Jul 27, 2021
e68a99a
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 27, 2021
dd105c8
change import package name
doombeaker Jul 27, 2021
620e1b2
remmove lazy test and add scatter_add and scatter_mul
doombeaker Jul 27, 2021
7bfd84f
startup of scatter backward op
doombeaker Jul 27, 2021
ee5c49c
add backward for scatter
doombeaker Jul 28, 2021
15fe104
scatter ops backward finished
doombeaker Jul 28, 2021
beac512
add scatter, scatter_add test cases
doombeaker Jul 28, 2021
206907f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
doombeaker Jul 28, 2021
6113eef
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
doombeaker Jul 28, 2021
e6ed34a
remove useless scatter_update_like
doombeaker Jul 28, 2021
6236405
reformat
doombeaker Jul 28, 2021
5f1c45d
refine test cases
doombeaker Jul 28, 2021
83a74ba
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 29, 2021
bad64a9
refine according to comments
doombeaker Jul 29, 2021
b9581c3
Merge branch 'debug_dim_scatter' of https://github.com/Oneflow-Inc/on…
doombeaker Jul 29, 2021
c875bd1
revert op_exprt_helper
doombeaker Jul 29, 2021
8d7710b
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 29, 2021
a4deca8
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 29, 2021
a33840a
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 30, 2021
bf787c6
fixed index element
MARD1NO Jul 30, 2021
69b81dc
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 30, 2021
7a96b95
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
e5bc436
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
1bb8e83
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
1f70313
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
8f02b0a
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
14ddf9e
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
577fc6c
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
fbd657b
Merge branch 'master' into debug_dim_scatter
oneflow-ci-bot Jul 30, 2021
fa051ae
fix scatter update like expr for dim gather backward
doombeaker Jul 31, 2021
58bc172
Merge branch 'debug_dim_scatter' of https://github.com/Oneflow-Inc/on…
doombeaker Jul 31, 2021
3a4acb1
Merge branch 'master' into debug_dim_scatter
doombeaker Jul 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ oneflow
reshape,
save,
saved_model,
scatter,
scatter_add,
scatter_nd,
selu,
silu,
Expand Down
176 changes: 176 additions & 0 deletions oneflow/core/autograd/gradient_funcs/dim_scatter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct DimScatterInterpState : public OpExprInterpState {
int32_t dim;
bool input_requires_grad;
bool src_requires_grad;
};

enum SCATTER_TYPE { SCATTER_UPDATE, SCATTER_ADD };

template<SCATTER_TYPE T>
class DimScatter : public OpExprGradFunction<DimScatterInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
Maybe<void> ApplyCommon(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const;

private:
AttrMap base_attrs_;
};

template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 3);
CHECK_EQ_OR_RETURN(outputs.size(), 1);

ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->src_requires_grad = inputs.at(2)->requires_grad();
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }

ctx->SaveTensorForBackward(inputs.at(1)); // index saved

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}

template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterInterpState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);

in_grads->resize(3);

if (ctx->src_requires_grad) {
in_grads->at(2) = JUST(functional::DimGather(out_grads.at(0), index, ctx->dim));
}
return Maybe<void>::Ok();
}

template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterInterpState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
JUST(ApplyCommon(ctx, out_grads, in_grads));

if (ctx->input_requires_grad) {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), index, 0.0f, ctx->dim));
}
return Maybe<void>::Ok();
}

template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterInterpState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);

JUST(ApplyCommon(ctx, out_grads, in_grads));

if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); }

return Maybe<void>::Ok();
}

class DimScatterUpdateScalar : public OpExprGradFunction<DimScatterInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
AttrMap base_attrs_;
};

Maybe<void> DimScatterUpdateScalar::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());

return Maybe<void>::Ok();
}

Maybe<void> DimScatterUpdateScalar::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
CHECK_EQ_OR_RETURN(outputs.size(), 1);

ctx->input_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }

ctx->SaveTensorForBackward(inputs.at(1)); // index saved

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}

Maybe<void> DimScatterUpdateScalar::Apply(const DimScatterInterpState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);

in_grads->resize(2);

MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", ctx->dim));
JUST(attrs.SetAttr<float>("src_scalar", 0.0f));
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), index, 0.0f, ctx->dim););

return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update", DimScatter<SCATTER_TYPE::SCATTER_UPDATE>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_add", DimScatter<SCATTER_TYPE::SCATTER_ADD>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update_scalar", DimScatterUpdateScalar);

} // namespace one
} // namespace oneflow
16 changes: 16 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,22 @@
signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
bind_python: True

- name: "dim_scatter"
signature: "Tensor DimScatter(Tensor input, Tensor index, Tensor src, *, Int32 dim)"
bind_python: True

- name: "dim_scatter_add"
signature: "Tensor DimScatterAdd(Tensor input, Tensor index, Tensor src, *, Int32 dim)"
bind_python: True

- name: "dim_scatter_scalar"
signature: "Tensor DimScatterUpdateScalar(Tensor input, Tensor index, *, Float src, Int32 dim)"
bind_python: True

- name: "dim_scatter_add_scalar"
signature: "Tensor DimScatterAddScalar(Tensor input, Tensor index, *, Float src, Int32 dim)"
bind_python: True

- name: "tensor_setitem"
signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)"
bind_python: True
Expand Down
138 changes: 138 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,138 @@ class DimGatherFunctor {
std::shared_ptr<OpExpr> op_;
};

class DimScatterFunctor {
public:
DimScatterFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_update")
.Input("input")
.Input("index")
.Input("src")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index,
const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class DimScatterAddFunctor {
public:
DimScatterAddFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_add")
.Input("input")
.Input("index")
.Input("src")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index,
const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class DimScatterMulFunctor {
public:
DimScatterMulFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_mul")
.Input("input")
.Input("index")
.Input("src")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index,
const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class DimScatterUpdateScalarFunctor {
public:
DimScatterUpdateScalarFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_update_scalar")
.Input("input")
.Input("index")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index, const float& src,
const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
JUST(attrs.SetAttr<float>("src_scalar", src));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class DimScatterAddScalarFunctor {
public:
DimScatterAddScalarFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_add_scalar")
.Input("input")
.Input("index")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index, const float& src,
const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
JUST(attrs.SetAttr<float>("src_scalar", src));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class DimScatterMulScalarFunctor {
public:
DimScatterMulScalarFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_mul_scalar")
.Input("input")
.Input("index")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index, const float& src,
const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
JUST(attrs.SetAttr<float>("src_scalar", src));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class GatherNdFunctor {
public:
GatherNdFunctor() {
Expand Down Expand Up @@ -997,6 +1129,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::DiagFunctor>("Diag");
m.add_functor<impl::DiagGradFunctor>("DiagGrad");
m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
m.add_functor<impl::DimScatterFunctor>("DimScatter");
m.add_functor<impl::DimScatterAddFunctor>("DimScatterAdd");
m.add_functor<impl::DimScatterMulFunctor>("DimScatterMul");
m.add_functor<impl::DimScatterUpdateScalarFunctor>("DimScatterUpdateScalar");
m.add_functor<impl::DimScatterAddScalarFunctor>("DimScatterAddScalar");
m.add_functor<impl::DimScatterMulScalarFunctor>("DimScatterMulScalar");
m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
};

Expand Down
13 changes: 0 additions & 13 deletions oneflow/user/kernels/dim_gather_kernel_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,7 @@ struct DimGatherFunctor<DeviceType::kCPU, IN_T, IDX_T> final {
}
};

template<typename IN_T, typename IDX_T>
struct DimScatterAddFunctor<DeviceType::kCPU, IN_T, IDX_T> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) {
DoDimScatterAdd<IN_T, IDX_T>(input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index,
input, output);
}
};

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kCPU),
DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_SCATTER_ADD_FUNCTOR, (DeviceType::kCPU),
DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ);

} // namespace user_op
} // namespace oneflow
Loading