[Functional] Part6: Migrate conv op (#5252)
* Add partial unary and math functional apis.

* Revert elementwise pow.

* auto format by CI

* Support add with large number of inputs.

* Update oneflow/python/nn/modules/math_ops.py

Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>

* Refine

* Migrate binary and activation ops.

* Migrate array ops.

* Add or refactor activation grad funcs.

* Add or refactor activation grad funcs.

* Revert unpack all

* Fix masked fill

* Refine

* Add nn ops.

* Refine

* Refine

* Migrate conv op

* Fix functional normalization.

* auto format by CI

* Refine code style

* align Torch params

* fix bias add error

* Support optional parameter.

* fix group bug

* add new test case

* remove useless state

* Move optional storage into namespace internal.

* add check

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: Luyang <flowingsun007@163.com>
Co-authored-by: MARD1NO <359521840@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
7 people committed Jun 25, 2021
1 parent 695dd7e commit deff67e
Showing 13 changed files with 677 additions and 158 deletions.
5 changes: 5 additions & 0 deletions oneflow/api/python/functional/python_arg.cpp
@@ -70,6 +70,11 @@ Maybe<std::shared_ptr<one::Tensor>> PythonArg::ObjectAs<std::shared_ptr<one::Ten
return detail::cast<std::shared_ptr<one::Tensor>>(Borrow());
}

template<>
Maybe<one::Tensor> PythonArg::ObjectAs<one::Tensor>() const {
return *JUST(detail::cast<std::shared_ptr<one::Tensor>>(Borrow()));
}

template<>
Maybe<std::shared_ptr<one::TensorTuple>> PythonArg::ObjectAs<std::shared_ptr<one::TensorTuple>>()
const {
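
The python_arg.cpp change above adds a value-typed specialization next to the existing shared_ptr one, so a PythonArg can also be converted straight to a one::Tensor by reusing the pointer path and dereferencing the result. A standalone sketch of the same pattern (illustrative only, with made-up names Arg/Foo; not OneFlow code):

#include <cassert>
#include <memory>

struct Foo { int x; };

class Arg {
 public:
  explicit Arg(std::shared_ptr<Foo> p) : p_(std::move(p)) {}

  template<typename T>
  T As() const;  // primary template, specialized below

 private:
  std::shared_ptr<Foo> p_;
};

// Pointer-typed converter: hand back the stored shared_ptr.
template<>
std::shared_ptr<Foo> Arg::As<std::shared_ptr<Foo>>() const { return p_; }

// Value-typed converter: reuse the pointer path and dereference, mirroring
// the new ObjectAs<one::Tensor> specialization above.
template<>
Foo Arg::As<Foo>() const { return *As<std::shared_ptr<Foo>>(); }

int main() {
  Arg arg(std::make_shared<Foo>(Foo{7}));
  assert(arg.As<std::shared_ptr<Foo>>()->x == 7);
  assert(arg.As<Foo>().x == 7);  // value conversion built on the pointer one
  return 0;
}
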
17 changes: 16 additions & 1 deletion oneflow/api/python/functional/python_arg.h
@@ -61,6 +61,21 @@ class PythonArg {

virtual ~PythonArg() = default;

template<typename T>
friend class ObjectAsHelper;

template<typename T>
struct ObjectAsHelper {
Maybe<T> operator()(const PythonArg* self) { return self->ObjectAs<T>(); }
};
template<typename T>
struct ObjectAsHelper<Optional<T>> {
Maybe<Optional<T>> operator()(const PythonArg* self) {
if (self->object_ == Py_None) { return std::make_shared<Optional<T>>(); }
return std::make_shared<Optional<T>>(JUST(self->ObjectAs<T>()));
}
};

template<typename T>
operator T() const {
if (active_tag_ == HAS_IMMEDIATE) {
@@ -70,7 +85,7 @@
return *reinterpret_cast<const T*>(immediate_->Ptr());
}
CHECK_EQ_OR_THROW(active_tag_, HAS_OBJECT);
return this->ObjectAs<oneflow::detail::remove_cvref_t<T>>().GetOrThrow();
return ObjectAsHelper<oneflow::detail::remove_cvref_t<T>>()(this).GetOrThrow();
}

private:
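
The ObjectAsHelper indirection in python_arg.h exists because a function template cannot be partially specialized: forwarding through a class template lets Optional<T> get its own branch, which maps Python None to an empty Optional and wraps any other object via the plain ObjectAs<T>(). A self-contained sketch of that dispatch pattern (illustrative only, using a toy Arg/Optional; not OneFlow code):

#include <iostream>

template<typename T>
struct Optional {  // toy stand-in for oneflow::Optional<T>
  bool has_value;
  T value;
};

struct Arg {
  bool is_none;  // plays the role of object_ == Py_None
  int raw;

  // Primary helper: ordinary conversion.
  template<typename T>
  struct AsHelper {
    T operator()(const Arg* self) const { return static_cast<T>(self->raw); }
  };
  // Partial specialization: None becomes an empty Optional, anything else is wrapped.
  template<typename T>
  struct AsHelper<Optional<T>> {
    Optional<T> operator()(const Arg* self) const {
      if (self->is_none) { return Optional<T>{}; }
      return Optional<T>{true, AsHelper<T>()(self)};
    }
  };

  template<typename T>
  T As() const { return AsHelper<T>()(this); }
};

int main() {
  Arg value{false, 42};
  Arg none{true, 0};
  std::cout << value.As<int>() << "\n";                      // 42
  std::cout << value.As<Optional<int>>().has_value << "\n";  // 1
  std::cout << none.As<Optional<int>>().has_value << "\n";   // 0
  return 0;
}
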
110 changes: 50 additions & 60 deletions oneflow/core/autograd/gradient_funcs/conv.cpp
@@ -13,103 +13,93 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/framework/user_op_conf_trait.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

namespace {

struct ConvInterpState : public OpExprInterpState {
bool weight_requires_grad = true;
bool input_requires_grad = true;
struct ConvolutionNdInterpState : public OpExprInterpState {
bool input_requires_grad = false;
bool weight_requires_grad = false;
size_t input_index;
size_t weight_index;

std::string data_format;
std::vector<int32_t> padding_before;
std::vector<int32_t> kernel_size;
std::vector<int32_t> strides;
std::vector<int32_t> dilation_rate;
int32_t groups;
};

class ConvNdGrad : public OpExprGradFunction<ConvInterpState> {
class ConvolutionNd : public OpExprGradFunction<ConvolutionNdInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ConvInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const ConvInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
std::shared_ptr<user_op::UserOpConfTrait> op_trait_;
std::shared_ptr<std::string> data_format_;
std::shared_ptr<std::vector<int32_t>> padding_before_;
std::shared_ptr<std::vector<int32_t>> kernel_size_;
std::shared_ptr<std::vector<int32_t>> strides_;
std::shared_ptr<std::vector<int32_t>> dilation_rate_;
int32_t groups_;

std::shared_ptr<OpExpr> data_grad_op_;
std::shared_ptr<OpExpr> weight_grad_op_;
AttrMap base_attrs_;
};

Maybe<void> ConvNdGrad::Init(const OpExpr& op) {
Maybe<void> ConvolutionNd::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
const std::string& op_name = fw_op_expr->op_name();
op_trait_ = std::make_shared<user_op::UserOpConfTrait>(op_name, fw_op_expr->proto());

data_format_ = JUST(op_trait_->GetAttr<std::string>("data_format"));
padding_before_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("padding_before"));
kernel_size_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("kernel_size"));
strides_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("strides"));
dilation_rate_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("dilation_rate"));
groups_ = JUST(op_trait_->GetAttr<int32_t>("groups"));
int32_t ndims = kernel_size_->size();
CHECK_EQ_OR_RETURN(ndims, strides_->size());
CHECK_EQ_OR_RETURN(ndims, dilation_rate_->size());
data_grad_op_ = JUST(op_expr_helper::ConvNdDataGradOp(*kernel_size_, *strides_, *padding_before_,
*dilation_rate_, groups_, *data_format_));

weight_grad_op_ = JUST(op_expr_helper::ConvNdFilterGradOp(
*kernel_size_, *strides_, *padding_before_, *dilation_rate_, groups_, *data_format_));

base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> ConvNdGrad::Capture(ConvInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
Maybe<void> ConvolutionNd::Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->weight_requires_grad = inputs.at(1)->requires_grad();
ctx->SaveTensorForBackward(inputs.at(0)); // x
if (!ctx->input_requires_grad && !ctx->weight_requires_grad) { return Maybe<void>::Ok(); }
if (ctx->input_requires_grad) {
ctx->SaveTensorForBackward(inputs.at(1)); // weight
ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1)); // weight
}
ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
return Maybe<void>::Ok();
}

Maybe<void> ConvNdGrad::Apply(const ConvInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
const auto& dy = out_grads.at(0);

Maybe<void> ConvolutionNd::Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
in_grads->resize(2);
size_t num_spatial_dims = ctx->kernel_size.size();
if (ctx->input_requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
const auto& weight = ctx->SavedTensors().at(1);
in_grads->at(0) =
JUST(OpInterpUtil::Dispatch<Tensor>(*data_grad_op_, {dy, weight, x}, /*attrs=*/{}));
const auto& weight = ctx->SavedTensors().at(ctx->weight_index);
const auto& input = ctx->SavedTensors().at(ctx->input_index);
in_grads->at(0) = JUST(functional::ConvDataGrad(
out_grads.at(0), weight, input, num_spatial_dims, ctx->kernel_size, ctx->strides,
ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
}
if (ctx->weight_requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(1) = JUST(OpInterpUtil::Dispatch<Tensor>(*weight_grad_op_, {dy, x}, /*attrs=*/{}));
const auto& input = ctx->SavedTensors().at(ctx->input_index);
in_grads->at(1) = JUST(functional::ConvFilterGrad(
out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides,
ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
}
return Maybe<void>::Ok();
}

} // namespace

REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvolutionNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvolutionNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvolutionNd);

} // namespace one
} // namespace oneflow
202 changes: 202 additions & 0 deletions oneflow/core/common/optional.h
@@ -0,0 +1,202 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef ONEFLOW_CORE_COMMON_OPTIONAL_H_
#define ONEFLOW_CORE_COMMON_OPTIONAL_H_

#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/maybe.h"

namespace oneflow {

namespace internal {

template<typename T, typename U = void>
class Storage;

template<typename T>
class Storage<T, typename std::enable_if<IsScalarType<T>::value>::type> {
public:
Storage() = default;

template<typename... Args,
typename std::enable_if<std::is_constructible<T, Args...>::value, int>::type = 0>
Storage(Args&&... args) {
new (&value_) T(std::forward<Args>(args)...);
}

Storage& operator=(const T& value) {
value_ = value;
return *this;
}
Storage& operator=(T&& value) {
value_ = std::move(value);
return *this;
}
Storage& operator=(const Storage<T>& rhs) {
value_ = rhs.value_;
return *this;
}
Storage& operator=(Storage<T>&& rhs) {
value_ = std::move(rhs.value_);
return *this;
}

Maybe<T> value() const { return value_; }

private:
T value_;
};

template<typename T>
class Storage<T, typename std::enable_if<!IsScalarType<T>::value>::type> {
public:
Storage() = default;

template<typename... Args,
typename std::enable_if<std::is_constructible<T, Args...>::value, int>::type = 0>
Storage(Args&&... args) {
value_ = std::make_shared<T>(std::forward<Args>(args)...);
}

Storage(const std::shared_ptr<T>& value) : value_(value) {}

Storage& operator=(const T& value) {
if (value_) {
*value_ = value;
} else {
value_ = std::make_shared<T>(value);
}
return *this;
}
Storage& operator=(T&& value) {
if (value_) {
*value_ = std::move(value);
} else {
value_ = std::make_shared<T>(value);
}
return *this;
}
Storage& operator=(const Storage<T>& rhs) {
value_ = rhs.value_;
return *this;
}
Storage& operator=(Storage<T>&& rhs) {
value_ = std::move(rhs.value_);
return *this;
}

Maybe<T> value() const { return value_; }

private:
std::shared_ptr<T> value_;
};

} // namespace internal

template<typename T>
class Optional {
public:
Optional() : init_(false) {}

template<typename... Args,
typename std::enable_if<std::is_constructible<internal::Storage<T>, Args...>::value,
int>::type = 0>
Optional(Args&&... args) : init_(true), storage_(std::forward<Args>(args)...) {}

~Optional() = default;

Optional(const Optional<T>& rhs) : init_(rhs.init_) {
if (init_) { storage_ = rhs.storage_; }
}

Optional(Optional<T>&& rhs) : init_(rhs.init_) {
if (init_) { storage_ = std::move(rhs.storage_); }
}

Optional& operator=(const T& val) {
init_ = true;
storage_ = val;
return *this;
}

Optional& operator=(T&& val) {
init_ = true;
storage_ = std::move(val);
return *this;
}

Optional& operator=(const Optional<T>& rhs) {
init_ = rhs.init_;
if (init_) { storage_ = rhs.storage_; }
return *this;
}

Optional& operator=(Optional<T>&& rhs) {
init_ = rhs.init_;
if (init_) { storage_ = std::move(rhs.storage_); }
return *this;
}

Maybe<T> value() const {
CHECK_OR_RETURN(has_value()) << "Optional has no value.";
return storage_.value();
}

bool has_value() const { return init_; }
operator bool() const { return has_value(); }

private:
bool init_;
internal::Storage<T> storage_;
};

template<typename T>
class Optional<T&> {
public:
Optional() : value_ptr_(nullptr) {}

Optional(T& val) : value_ptr_(&val) {}

~Optional() = default;

Optional& operator=(T& val) {
value_ptr_ = &val;
return *this;
}

Optional& operator=(const Optional<T&>& rhs) {
value_ptr_ = rhs.value_ptr_;
return *this;
}

Maybe<T&> value() const {
CHECK_OR_RETURN(has_value()) << "Optional has no value.";
return *value_ptr_;
}

void Clear() { value_ptr_ = nullptr; }

bool has_value() const { return value_ptr_ != nullptr; }
operator bool() const { return has_value(); }

private:
T* value_ptr_;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_COMMON_OPTIONAL_H_
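
For context, a minimal usage sketch of the new Optional<T> (an assumed call site, not part of this commit), following the usual Maybe/JUST conventions: an empty Optional is detected with has_value() or operator bool, and value() returns a Maybe<T>, so unwrapping an unset value surfaces an error instead of undefined behavior.

#include "oneflow/core/common/optional.h"

namespace oneflow {

// Hypothetical helper, only for illustration.
Maybe<int> DoubleOrDefault(const Optional<int>& x) {
  if (!x) { return 0; }      // empty Optional: fall back to a default
  int v = JUST(x.value());   // unwrap; an unset value would propagate an error
  return v * 2;
}

}  // namespace oneflow
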