Skip to content

Commit

Permalink
Support the grid_sample and affine_grid operators
Browse files Browse the repository at this point in the history
  • Loading branch information
tingkuanpei committed Aug 25, 2021
1 parent ab5ca8f commit 971a85b
Show file tree
Hide file tree
Showing 19 changed files with 3,584 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ Functional operations for neural networks
.. autofunction:: one_hot
.. autofunction:: dropout
.. autofunction:: upsample
.. autofunction:: affine_grid
.. autofunction:: grid_sample
70 changes: 70 additions & 0 deletions oneflow/core/autograd/gradient_funcs/affine_grid.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

// State captured by the forward pass of "affine_grid" for use in backward.
struct AffineGridInterpState : public AutoGradCaptureState {
  Shape size;           // "size" attr of the forward op, replayed into the grad op
  bool align_corners;   // "align_corners" attr, replayed into the grad op
  bool requires_grad;   // whether theta (the single input) needs a gradient
};

// Backward rule for the "affine_grid" user op; delegates the actual
// computation to functional::AffineGridGrad.
class AffineGrid : public OpExprGradFunction<AffineGridInterpState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(user_op);
    base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(AffineGridInterpState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    const auto& theta = inputs.at(0);
    ctx->requires_grad = theta->requires_grad();
    // Skip attr capture entirely when no gradient will be requested.
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->size = JUST(composed_attrs.GetAttr<Shape>("size"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const AffineGridInterpState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(1);
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    const auto& grid_diff = out_grads.at(0);
    in_grads->at(0) = JUST(functional::AffineGridGrad(grid_diff, ctx->size, ctx->align_corners));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // forward op attrs used as fallback defaults in Capture
};

REGISTER_OP_EXPR_GRAD_FUNCTION("affine_grid", AffineGrid);

} // namespace one
} // namespace oneflow
88 changes: 88 additions & 0 deletions oneflow/core/autograd/gradient_funcs/grid_sample.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

// State captured by the forward pass of "grid_sample" for use in backward.
// (All "Gird" misspellings from the original revision corrected to "Grid";
// the unused `Shape size` member was removed.)
struct GridSampleInterpState : public AutoGradCaptureState {
  std::string interpolation_mode;  // replayed verbatim into the grad op
  std::string padding_mode;        // replayed verbatim into the grad op
  bool align_corners;
  size_t input_index;        // slot of the saved "input" tensor in SavedTensors()
  size_t grid_index;         // slot of the saved "grid" tensor in SavedTensors()
  bool input_requires_grad;  // produce d(input)?
  bool grid_requires_grad;   // produce d(grid)?
  bool requires_grad;        // input_requires_grad || grid_requires_grad
};

// Backward rule for the "grid_sample" user op; delegates to
// functional::GridSampleGrad, which returns {d(input), d(grid)}.
class GridSample : public OpExprGradFunction<GridSampleInterpState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(GridSampleInterpState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);
    ctx->input_requires_grad = inputs.at(0)->requires_grad();
    ctx->grid_requires_grad = inputs.at(1)->requires_grad();
    ctx->requires_grad = ctx->input_requires_grad || ctx->grid_requires_grad;
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    // The backward kernel needs both tensors regardless of which of the two
    // gradients is actually requested.
    ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));  // input
    ctx->grid_index = ctx->SaveTensorForBackward(inputs.at(1));   // grid

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->interpolation_mode = JUST(composed_attrs.GetAttr<std::string>("interpolation_mode"));
    ctx->padding_mode = JUST(composed_attrs.GetAttr<std::string>("padding_mode"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const GridSampleInterpState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);

    if (ctx->requires_grad) {
      const auto& input = ctx->SavedTensors().at(ctx->input_index);
      const auto& grid = ctx->SavedTensors().at(ctx->grid_index);
      const auto& results =
          JUST(functional::GridSampleGrad(out_grads.at(0), input, grid, ctx->interpolation_mode,
                                          ctx->padding_mode, ctx->align_corners));

      in_grads->resize(2);
      // Only propagate the gradients that were actually requested.
      if (ctx->input_requires_grad) { in_grads->at(0) = results->at(0); }
      if (ctx->grid_requires_grad) { in_grads->at(1) = results->at(1); }
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // forward op attrs used as fallback defaults in Capture
};

REGISTER_OP_EXPR_GRAD_FUNCTION("grid_sample", GridSample);

} // namespace one
} // namespace oneflow
20 changes: 20 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,26 @@
Float m1, Float m2, Float m3, Int64 depth)"
bind_python: False

# affine_grid: builds a sampling grid from a batch of affine matrices (theta).
- name: "affine_grid"
signature:
"Tensor AffineGrid(Tensor theta, *, Shape size, Bool align_corners)"
bind_python: True

# Backward of affine_grid; not exposed to Python (used by autograd only).
- name: "affine_grid_grad"
signature:
"Tensor AffineGridGrad(Tensor dgrid, *, Shape size, Bool align_corners)"
bind_python: False

# grid_sample: samples `input` at the locations given by `grid`.
- name: "grid_sample"
signature:
"Tensor GridSample(Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners)"
bind_python: True

# Backward of grid_sample; returns {d(input), d(grid)}. Autograd-only.
- name: "grid_sample_grad"
signature:
"TensorTuple GridSampleGrad(Tensor doutput, Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners)"
bind_python: False

- name: "where"
signature: "Tensor Where(Tensor condition, Tensor x, Tensor y)"
bind_python: False
Expand Down
40 changes: 40 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,44 @@ class CombinedMarginLossFunctor {
std::shared_ptr<OpExpr> op_;
};

// Dispatches the "affine_grid" user op: theta -> sampling grid.
class AffineGridFunctor {
 public:
  AffineGridFunctor()
      : op_(CHECK_JUST(one::OpBuilder("affine_grid").Input("theta").Output("grid").Build())) {}

  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& theta, const Shape& size,
                           const bool& align_corners) const {
    // Pack the op attributes and dispatch through the interpreter.
    MutableAttrMap attr_map;
    JUST(attr_map.SetAttr<Shape>("size", size));
    JUST(attr_map.SetAttr<bool>("align_corners", align_corners));
    return OpInterpUtil::Dispatch<Tensor>(*op_, {theta}, attr_map);
  }

 private:
  std::shared_ptr<OpExpr> op_;  // cached op expression, built once per functor
};

// Dispatches the "grid_sample" user op: (input, grid) -> sampled output.
class GridSampleFunctor {
 public:
  GridSampleFunctor()
      : op_(CHECK_JUST(
            one::OpBuilder("grid_sample").Input("input").Input("grid").Output("output").Build())) {}

  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
                           const std::shared_ptr<one::Tensor>& grid,
                           const std::string& interpolation_mode, const std::string& padding_mode,
                           const bool& align_corners) const {
    // Pack the op attributes and dispatch through the interpreter.
    MutableAttrMap attr_map;
    JUST(attr_map.SetAttr<std::string>("interpolation_mode", interpolation_mode));
    JUST(attr_map.SetAttr<std::string>("padding_mode", padding_mode));
    JUST(attr_map.SetAttr<bool>("align_corners", align_corners));
    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, grid}, attr_map);
  }

 private:
  std::shared_ptr<OpExpr> op_;  // cached op expression, built once per functor
};

class NormalizationFunctor {
public:
NormalizationFunctor() {
Expand Down Expand Up @@ -757,6 +795,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::SparseSoftmaxCrossEntropyFunctor>("SparseSoftmaxCrossEntropy");
m.add_functor<impl::SmoothL1LossFunctor>("SmoothL1Loss");
m.add_functor<impl::CombinedMarginLossFunctor>("CombinedMarginLoss");
m.add_functor<impl::AffineGridFunctor>("AffineGrid");
m.add_functor<impl::GridSampleFunctor>("GridSample");
m.add_functor<impl::NormalizationFunctor>("Normalization");
m.add_functor<impl::PadFunctor>("Pad");
m.add_functor<impl::DropoutFunctor>("Dropout");
Expand Down
46 changes: 46 additions & 0 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,50 @@ class CombinedMarginLossGradFunctor {
std::shared_ptr<OpExpr> op_;
};

// Dispatches the "affine_grid_grad" user op: d(grid) -> d(theta).
class AffineGridGradFunctor {
 public:
  AffineGridGradFunctor()
      : op_(CHECK_JUST(
            one::OpBuilder("affine_grid_grad").Input("dgrid").Output("dtheta").Build())) {}

  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dgrid, const Shape& size,
                           const bool& align_corners) const {
    // Replay the forward op's attributes into the backward op.
    MutableAttrMap attr_map;
    JUST(attr_map.SetAttr<Shape>("size", size));
    JUST(attr_map.SetAttr<bool>("align_corners", align_corners));
    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dgrid}, attr_map);
  }

 private:
  std::shared_ptr<OpExpr> op_;  // cached op expression, built once per functor
};

// Dispatches the "grid_sample_grad" user op:
// (d(output), input, grid) -> {d(input), d(grid)}.
class GridSampleGradFunctor {
 public:
  GridSampleGradFunctor()
      : op_(CHECK_JUST(one::OpBuilder("grid_sample_grad")
                           .Input("doutput")
                           .Input("input")
                           .Input("grid")
                           .Output("dinput")
                           .Output("dgrid")
                           .Build())) {}

  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& doutput,
                                const std::shared_ptr<one::Tensor>& input,
                                const std::shared_ptr<one::Tensor>& grid,
                                const std::string& interpolation_mode,
                                const std::string& padding_mode, const bool& align_corners) const {
    // Replay the forward op's attributes into the backward op.
    MutableAttrMap attr_map;
    JUST(attr_map.SetAttr<std::string>("interpolation_mode", interpolation_mode));
    JUST(attr_map.SetAttr<std::string>("padding_mode", padding_mode));
    JUST(attr_map.SetAttr<bool>("align_corners", align_corners));
    return OpInterpUtil::Dispatch<one::TensorTuple>(*op_, {doutput, input, grid}, attr_map);
  }

 private:
  std::shared_ptr<OpExpr> op_;  // cached op expression, built once per functor
};

class PadGradFunctor {
public:
PadGradFunctor() {
Expand Down Expand Up @@ -412,6 +456,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::AdaptivePoolNdGradFunctor>("AdaptivePoolNdGrad");
m.add_functor<impl::SmoothL1LossGradFunctor>("SmoothL1LossGrad");
m.add_functor<impl::CombinedMarginLossGradFunctor>("CombinedMarginLossGrad");
m.add_functor<impl::AffineGridGradFunctor>("AffineGridGrad");
m.add_functor<impl::GridSampleGradFunctor>("GridSampleGrad");
m.add_functor<impl::PoolingNdGradFunctor>("PoolingNdGrad");
m.add_functor<impl::PadGradFunctor>("PadGrad");
m.add_functor<impl::AvgPoolingNdGradFunctor>("AvgPoolingNdGrad");
Expand Down

0 comments on commit 971a85b

Please sign in to comment.