
Support inplace operations #5204

Merged · 36 commits · Jul 7, 2021
Changes from 15 commits

Commits (36)
8795a7b
support inplace forward
poohRui Jun 15, 2021
455e34c
support inplace backward
poohRui Jun 15, 2021
1d6d80c
add test case
poohRui Jun 15, 2021
15390c6
add test case for clone
poohRui Jun 15, 2021
d01d59e
inplace is not support for leaf nodes
poohRui Jun 16, 2021
6dcf506
refine clone
poohRui Jun 16, 2021
7844574
add checks
poohRui Jun 17, 2021
6abb45c
refine
poohRui Jun 17, 2021
530aa41
forbid clone with no grad
poohRui Jun 17, 2021
76605d9
Separate autograd meta to tensor (#5267)
poohRui Jun 23, 2021
4d74ea1
conflict
poohRui Jun 23, 2021
f55d807
inplace without clone
poohRui Jun 24, 2021
f583175
refine
poohRui Jun 24, 2021
77ea879
minor fix
poohRui Jun 24, 2021
dede4f6
remove maybe from constructor
poohRui Jun 25, 2021
25de718
change from create to set
poohRui Jun 28, 2021
45035a7
Merge remote-tracking branch 'origin/master' into dev_support_inplace
wyg1997 Jul 1, 2021
4405fd4
fix merge bugs
wyg1997 Jul 1, 2021
1622323
fix merge bug
wyg1997 Jul 1, 2021
da8e198
remove inplace flag in local_call_opkernel_phy_instr_operand
wyg1997 Jul 1, 2021
48a72e3
Merge remote-tracking branch 'origin/master' into dev_support_inplace
wyg1997 Jul 5, 2021
c09e936
remove out-date codes
wyg1997 Jul 5, 2021
06f9cc3
Merge branch 'master' into dev_support_inplace
wyg1997 Jul 6, 2021
ba40b2f
refine code
wyg1997 Jul 7, 2021
799bbfb
Merge branch 'master' into dev_support_inplace
hjchen2 Jul 7, 2021
e378fa9
Merge branch 'master' into dev_support_inplace
oneflow-ci-bot Jul 7, 2021
6c56553
Merge branch 'master' into dev_support_inplace
oneflow-ci-bot Jul 7, 2021
819c847
Merge branch 'master' into dev_support_inplace
oneflow-ci-bot Jul 7, 2021
c140c13
Merge branch 'master' into dev_support_inplace
oneflow-ci-bot Jul 7, 2021
a410b76
add JUST
wyg1997 Jul 7, 2021
3268162
Merge branch 'master' into dev_support_inplace
oneflow-ci-bot Jul 7, 2021
5dd344c
fix merge master bug
wyg1997 Jul 7, 2021
794c23a
Merge branch 'master' into dev_support_inplace
oneflow-ci-bot Jul 7, 2021
21364ca
revert autograd engine input_grad check
wyg1997 Jul 7, 2021
8462703
Merge branch 'master' into dev_support_inplace
oneflow-ci-bot Jul 7, 2021
adc8211
fix bug in tensor_hook
wyg1997 Jul 7, 2021
3 changes: 2 additions & 1 deletion oneflow/api/python/framework/tensor.cpp
@@ -247,9 +247,10 @@ void ExportTensor(py::module& m, const char* name) {
// Methods of pytorch
.def("retain_grad",
[](T& t) {
if (!t.is_leaf()) { t.set_retain_grad(true); }
if (!t.is_leaf()) { t.set_retain_grad(true).GetOrThrow(); }
})
.def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); })
.def("clone", [](const T& t) { return t.api_clone().GetPtrOrThrow(); })
// OneFlow tensor properties other than pytorch tensor
.def_property_readonly("is_lazy", &T::is_lazy)
.def_property_readonly("is_consistent", &T::is_consistent);
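The binding change above replaces `t.set_retain_grad(true)` with `t.set_retain_grad(true).GetOrThrow()`: the setter now returns a `Maybe`, and the Python boundary has to surface a stored error as an exception instead of dropping it. Below is a minimal sketch of that pattern using made-up types (`MaybeVoid`, `RetainGradBinding`); it is not OneFlow's actual `Maybe` implementation.

```cpp
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>

// Stand-in for a Maybe<void>-style result: either "ok" or an error message.
class MaybeVoid {
 public:
  static MaybeVoid Ok() { return MaybeVoid(std::nullopt); }
  static MaybeVoid Error(std::string msg) { return MaybeVoid(std::move(msg)); }

  // At the Python boundary a stored error becomes a C++ exception,
  // which the binding layer can translate into a Python exception.
  void GetOrThrow() const {
    if (error_) { throw std::runtime_error(*error_); }
  }

 private:
  explicit MaybeVoid(std::optional<std::string> error) : error_(std::move(error)) {}
  std::optional<std::string> error_;
};

// Shape of the binding after the change: the Maybe-returning setter is
// immediately followed by GetOrThrow() so failures are not silently dropped.
void RetainGradBinding(bool is_leaf) {
  if (!is_leaf) {
    // Stands in for t.set_retain_grad(true); any recorded error is thrown here.
    MaybeVoid::Ok().GetOrThrow();
  }
}
```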
14 changes: 8 additions & 6 deletions oneflow/core/autograd/autograd_engine.cpp
@@ -68,15 +68,16 @@ StackFunctionNode::StackFunctionNode(
input_meta_datas_.resize(inputs.size());
next_functions_->reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta();
if (input_meta_datas_.at(i)->requires_grad()) {
if (inputs.at(i)->requires_grad()) {
input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta();
next_functions_->emplace_back(inputs.at(i)->grad_fn_node());
}
}

output_meta_datas_.resize(outputs.size());
output_tensor_infos_.reserve(outputs.size());
for (int i = 0; i < outputs.size(); ++i) {
outputs.at(i)->create_autograd_meta();
output_meta_datas_.at(i) = outputs.at(i)->mut_autograd_meta();
output_tensor_infos_.emplace_back(TensorInfo(*outputs.at(i)));
}
@@ -129,6 +130,7 @@ Maybe<bool> StackFunctionNode::Apply(bool create_graph) {
JUST((*backward_fn_)(output_grads, &input_grads, create_graph));
for (int i = 0; i < input_meta_datas_.size(); ++i) {
if (input_grads.at(i)) {
CHECK_NOTNULL_OR_RETURN(input_meta_datas_.at(i));
JUST(input_meta_datas_.at(i)->now_grad_arg()->PushPartialTensor(input_grads.at(i)));
}
}
@@ -151,7 +153,7 @@ Maybe<void> StackAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
bool create_graph) {
ClearReleasedFunctionNodes();
for (int i = 0; i < outputs.size(); ++i) {
JUST(outputs.at(i)->now_grad_arg()->PushPartialTensor(out_grads.at(i)));
JUST(JUST(outputs.at(i)->now_grad_arg())->PushPartialTensor(out_grads.at(i)));
}
// Runs each FunctionNode
for (const auto& weak_func_node : node_list_) {
@@ -179,7 +181,7 @@ Maybe<TensorTuple> StackAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
inputs.at(i)->set_retain_grad(true);
}
for (int i = 0; i < outputs.size(); ++i) {
JUST(outputs.at(i)->now_grad_arg()->PushPartialTensor(out_grads.at(i)));
JUST(JUST(outputs.at(i)->now_grad_arg())->PushPartialTensor(out_grads.at(i)));
}
// Runs each FunctionNode
for (const auto& weak_func_node : node_list_) {
@@ -192,9 +194,9 @@
}
}
for (int i = 0; i < inputs.size(); ++i) {
input_now_grads->at(i) = inputs.at(i)->acc_grad();
input_now_grads->at(i) = JUST(inputs.at(i)->acc_grad());
if (!ori_retain_grad.at(i)) {
inputs.at(i)->mut_acc_grad().reset();
JUST(inputs.at(i)->mut_acc_grad()).reset();
inputs.at(i)->set_retain_grad(false);
}
}
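The `autograd_engine.cpp` hunks change two related things: input autograd metadata is now recorded only for inputs that actually require grad (other slots stay null), and `Apply` gains a `CHECK_NOTNULL_OR_RETURN` before a partial gradient is pushed into such a slot. A minimal sketch of the first part follows, using simplified stand-in types rather than OneFlow's real classes.

```cpp
#include <memory>
#include <vector>

// Stand-ins for the real tensor / autograd-meta classes.
struct AutogradMeta {
  bool requires_grad = false;
};

struct TensorSketch {
  std::shared_ptr<AutogradMeta> meta = std::make_shared<AutogradMeta>();
  bool requires_grad() const { return meta->requires_grad; }
};

// Mirrors the constructor change: metadata (and a backward edge) is captured
// only for inputs that require grad; the other slots stay nullptr, which is
// why the later null check is needed before pushing a partial gradient.
std::vector<std::shared_ptr<AutogradMeta>> CollectInputMeta(
    const std::vector<TensorSketch>& inputs) {
  std::vector<std::shared_ptr<AutogradMeta>> input_meta(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i].requires_grad()) { input_meta[i] = inputs[i].meta; }
  }
  return input_meta;
}
```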
11 changes: 9 additions & 2 deletions oneflow/core/eager/local_call_opkernel_phy_instr_operand.h
@@ -48,13 +48,19 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand {

LocalCallOpKernelPhyInstrOperand(const std::shared_ptr<one::StatefulLocalOpKernel>& opkernel,
const one::EagerBlobObjectListPtr& inputs,
const one::EagerBlobObjectListPtr& outputs, const AttrMap& attrs)
: opkernel_(opkernel), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
const one::EagerBlobObjectListPtr& outputs, const AttrMap& attrs,
bool is_inplace)
: opkernel_(opkernel),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
is_inplace_(is_inplace) {}

const one::StatefulLocalOpKernel& opkernel() const { return *opkernel_; }
const one::EagerBlobObjectListPtr& inputs() const { return inputs_; }
const one::EagerBlobObjectListPtr& outputs() const { return outputs_; }
const AttrMap& attrs() const { return attrs_; }
bool is_inplace() const { return is_inplace_; }

one::StatefulLocalOpKernel* mut_opkernel() { return opkernel_.get(); }

@@ -86,6 +92,7 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand {
one::EagerBlobObjectListPtr outputs_;
const AttrMap attrs_;
const user_op::OpKernel* user_opkernel_;
bool is_inplace_;
};

} // namespace vm
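The operand now receives the inplace decision as a constructor argument and exposes it through `is_inplace()`, so the instruction-execution path can branch on it without recomputing anything. The sketch below only illustrates that shape with stand-in types; it is not the real `LocalCallOpKernelPhyInstrOperand`.

```cpp
#include <memory>
#include <utility>

struct EagerBlobListSketch {};  // stands in for one::EagerBlobObjectList

// Sketch of the operand shape: the interpreter decides inplace-ness once and
// the flag travels with the instruction operand.
class CallOperandSketch {
 public:
  CallOperandSketch(std::shared_ptr<EagerBlobListSketch> inputs,
                    std::shared_ptr<EagerBlobListSketch> outputs, bool is_inplace)
      : inputs_(std::move(inputs)), outputs_(std::move(outputs)), is_inplace_(is_inplace) {}

  bool is_inplace() const { return is_inplace_; }

 private:
  std::shared_ptr<EagerBlobListSketch> inputs_;
  std::shared_ptr<EagerBlobListSketch> outputs_;
  bool is_inplace_;
};
```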
4 changes: 2 additions & 2 deletions oneflow/core/eager/opkernel_instruction_type.cpp
@@ -447,7 +447,7 @@ struct LocalCallOpKernelUtil final {
operand->set_user_opkernel(
JUST(operand->mut_opkernel()->ChooseOpKernel(operand->inputs(), operand->outputs())));
JUST(CheckOutputBlobObjectsMemCase(operand, instruction->stream()));
JUST(InitOutputBlobs(operand));
if (!operand->is_inplace()) { JUST(InitOutputBlobs(operand)); }
JUST(InferTempStorageBlobDesc(operand));
JUST(ResetTempStorageBlob(operand));
return Maybe<void>::Ok();
@@ -456,7 +456,7 @@
static inline Maybe<void> Compute(vm::Instruction* instruction) {
auto* operand = JUST(GetLocalCallOpKernelPhyInstrOperand(instruction));
DeviceCtx* device_ctx = instruction->stream().device_ctx().get();
JUST(AllocateOutputBlobsMemory(operand, device_ctx));
if (!operand->is_inplace()) { JUST(AllocateOutputBlobsMemory(operand, device_ctx)); }
JUST(TryAllocateTempStorageBlobMemory(operand, device_ctx));
user_op::OpKernelState* state;
TryInitOpKernelState(operand, device_ctx, &state);
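The two guarded calls above are the point of the flag: for an inplace call the output blob objects alias the input blob objects, so initializing output blobs or allocating output memory again would be redundant at best. A hedged illustration of the branch, with stand-in types:

```cpp
#include <cstdio>

struct OperandSketch {
  bool is_inplace = false;
};

// For a normal call the outputs are fresh tensors that need blob headers and
// device memory; for an inplace call they alias the input blobs, so both
// steps are skipped and the input storage is reused as-is.
void PrepareOutputs(const OperandSketch& operand) {
  if (!operand.is_inplace) {
    std::printf("InitOutputBlobs + AllocateOutputBlobsMemory\n");
  } else {
    std::printf("reuse input blob storage\n");
  }
}
```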
5 changes: 5 additions & 0 deletions oneflow/core/framework/device.cpp
@@ -119,6 +119,11 @@ std::string Device::ToString() const {
return ss.str();
}

std::ostream& operator<<(std::ostream& out, const Device& device) {
out << device.ToString();
return out;
}

Maybe<const Device> Device::MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc) {
std::string type = parallel_desc.device_tag();
if (parallel_desc.device_tag() == "gpu") { type = "cuda"; }
2 changes: 2 additions & 0 deletions oneflow/core/framework/device.h
@@ -66,6 +66,8 @@ class Device final : public std::enable_shared_from_this<Device> {
std::shared_ptr<VmLocalDepObject> compute_local_dep_object_;
};

std::ostream& operator<<(std::ostream& out, const Device& device);

} // namespace oneflow

namespace std {
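The new `operator<<` simply delegates to `Device::ToString()`, which makes devices printable in checks and log messages. A small self-contained usage sketch (the `DeviceSketch` type is a stand-in, not OneFlow's `Device`):

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Stand-in with the same ToString()-based printing contract as the diff.
struct DeviceSketch {
  std::string type;
  int device_id;
  std::string ToString() const {
    std::ostringstream ss;
    ss << type << ":" << device_id;
    return ss.str();
  }
};

std::ostream& operator<<(std::ostream& out, const DeviceSketch& device) {
  out << device.ToString();
  return out;
}

int main() {
  DeviceSketch d{"cuda", 0};
  std::cout << "running on " << d << std::endl;  // prints: running on cuda:0
  return 0;
}
```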
4 changes: 2 additions & 2 deletions oneflow/core/framework/instructions_builder.cpp
@@ -650,11 +650,11 @@ Maybe<void> InstructionsBuilder::LocalCallOpKernel(
const one::EagerBlobObjectListPtr& input_eager_blob_objects,
const one::EagerBlobObjectListPtr& output_eager_blob_objects, const AttrMap& attrs,
const std::shared_ptr<const ParallelDesc>& parallel_desc_sym,
const std::string& instr_type_name) {
const std::string& instr_type_name, bool is_inplace) {
ObjectMsgPtr<vm::InstructionMsg> instruction =
ObjectMsgPtr<vm::InstructionMsg>::New(instr_type_name);
auto phy_instr_operand = std::make_shared<vm::LocalCallOpKernelPhyInstrOperand>(
opkernel, input_eager_blob_objects, output_eager_blob_objects, attrs);
opkernel, input_eager_blob_objects, output_eager_blob_objects, attrs, is_inplace);
*instruction->mut_parallel_desc() = parallel_desc_sym;
*instruction->mutable_phy_instr_operand() = phy_instr_operand;
instruction_list_->EmplaceBack(std::move(instruction));
2 changes: 1 addition & 1 deletion oneflow/core/framework/instructions_builder.h
@@ -249,7 +249,7 @@ class InstructionsBuilder : public std::enable_shared_from_this<InstructionsBuil
const one::EagerBlobObjectListPtr& output_eager_blob_objects,
const AttrMap& attrs,
const std::shared_ptr<const ParallelDesc>& parallel_desc_sym,
const std::string& instr_type_name);
const std::string& instr_type_name, bool is_inplace);

private:
Maybe<void> RankFrontSeqCallback(const std::string& instruction_name,
@@ -55,59 +55,79 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
TensorTuple* outputs, const AttrMap& attrs) {
std::shared_ptr<EagerBlobObjectList> input_eager_blob_objects =
std::make_shared<EagerBlobObjectList>(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
for (int i = 0; i < inputs.size(); ++i) {
const auto& input_device = JUST(inputs.at(i)->device());
if (i > 0) {
CHECK_OR_RETURN(*default_device == *input_device) << Error::InputDeviceNotMatchError();
}
input_eager_blob_objects->at(i) = JUST(inputs.at(i)->eager_blob_object());
}
for (int i = 0; i < outputs->size(); i++) {
std::shared_ptr<EagerBlobObjectList> output_eager_blob_objects =
std::make_shared<EagerBlobObjectList>(outputs->size());
for (int i = 0; i < outputs->size(); ++i) {
if (!outputs->at(i)) {
outputs->at(i) =
std::make_shared<MirroredTensor>(std::make_shared<EagerMirroredTensorImpl>());
}
if (JUST(outputs->at(i)->has_eager_blob_object())) {
output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object());
}
}
std::shared_ptr<EagerBlobObjectList> output_eager_blob_objects =
std::make_shared<EagerBlobObjectList>(outputs->size());
std::shared_ptr<const Device> op_device;
std::shared_ptr<const ParallelDesc> op_parallel_desc;
bool need_check_mem_case = true;
bool need_event_record = false;

// Infer devices
if (!user_op_expr.has_device_infer_fn()) {
bool is_inplace =
std::all_of(output_eager_blob_objects->begin(), output_eager_blob_objects->end(),
[](const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
return eager_blob_object != nullptr;
});
if (is_inplace) {
for (int i = 0; i < outputs->size(); ++i) {
CHECK_EQ_OR_RETURN(*JUST(outputs->at(i)->device()), *JUST(inputs.at(i)->device()));
output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object());
CHECK_EQ_OR_RETURN(output_eager_blob_objects->at(i), input_eager_blob_objects->at(i));
CHECK_EQ_OR_RETURN(output_eager_blob_objects->at(i)->blob_desc().shape(),
input_eager_blob_objects->at(i)->blob_desc().shape());
}
op_device = default_device;
op_parallel_desc = op_device->parallel_desc_ptr();
for (int i = 0; i < outputs->size(); i++) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
*tensor_impl->mut_device() = default_device;
}
} else {
need_check_mem_case = false;
op_device = JUST(user_op_expr.InferDevices(attrs, inputs, outputs));
for (const auto& input_tensor : inputs) {
const auto& input_device = JUST(input_tensor->device());
need_event_record = need_event_record || !(*op_device == *input_device);
// Infer devices
if (!user_op_expr.has_device_infer_fn()) {
op_device = default_device;
op_parallel_desc = op_device->parallel_desc_ptr();
for (int i = 0; i < outputs->size(); ++i) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
*tensor_impl->mut_device() = default_device;
}
} else {
need_check_mem_case = false;
op_device = JUST(user_op_expr.InferDevices(attrs, inputs, outputs));
for (const auto& input_tensor : inputs) {
const auto& input_device = JUST(input_tensor->device());
need_event_record = need_event_record || !(*op_device == *input_device);
}
op_device = default_device;
op_parallel_desc = op_device->parallel_desc_ptr();
}
op_parallel_desc = op_device->parallel_desc_ptr();
}

// Infer shapes and dtypes
const auto& device_tag = JUST(op_device->of_type());
JUST(user_op_expr.InferLogicalShapeAndDType(
attrs, device_tag,
[&](int32_t i) -> const TensorMeta* {
return CHECK_JUST(TensorImpl4Tensor(inputs.at(i)))->tensor_meta().get();
},
[&](int32_t i) -> TensorMeta* {
return CHECK_JUST(TensorImpl4Tensor(outputs->at(i)))->mut_tensor_meta();
}));

for (int i = 0; i < output_eager_blob_objects->size(); i++) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
JUST(tensor_impl->InitEagerBlobObject(JUST(outputs->at(i)->device())->mem_case()));
output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
// Infer shapes and dtypes
const auto& device_tag = JUST(op_device->of_type());
JUST(user_op_expr.InferLogicalShapeAndDType(
attrs, device_tag,
[&](int32_t i) -> const TensorMeta* {
return CHECK_JUST(TensorImpl4Tensor(inputs.at(i)))->tensor_meta().get();
},
[&](int32_t i) -> TensorMeta* {
return CHECK_JUST(TensorImpl4Tensor(outputs->at(i)))->mut_tensor_meta();
}));

for (int i = 0; i < output_eager_blob_objects->size(); ++i) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
JUST(tensor_impl->InitEagerBlobObject(JUST(outputs->at(i)->device())->mem_case()));
output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
}
}

const auto kernel = JUST(user_op_expr.MutKernel4Device(*op_device));
@@ -130,7 +150,7 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
}
}
return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects,
attrs, op_parallel_desc, instr_type_name);
attrs, op_parallel_desc, instr_type_name, is_inplace);
}));
return Maybe<void>::Ok();
}
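The interpreter hunk above carries the core decision: a call is treated as inplace when every output tensor already has an eager blob object bound, and in that case each output must alias the corresponding input blob and agree with it on device and shape. A simplified sketch of that logic, with stand-in types and `assert` in place of `CHECK_*_OR_RETURN`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

struct BlobSketch {
  std::vector<int64_t> shape;
  int device_id = 0;
};
using BlobPtr = std::shared_ptr<BlobSketch>;

// Assumes one output per input, matching the pairwise checks in the diff.
bool DecideInplace(const std::vector<BlobPtr>& input_blobs,
                   const std::vector<BlobPtr>& output_blobs) {
  // Inplace iff every output slot was already bound to an eager blob object.
  const bool is_inplace =
      std::all_of(output_blobs.begin(), output_blobs.end(),
                  [](const BlobPtr& blob) { return blob != nullptr; });
  if (is_inplace) {
    assert(output_blobs.size() == input_blobs.size());
    for (size_t i = 0; i < output_blobs.size(); ++i) {
      // The output must literally alias the input blob and agree on device
      // and shape; otherwise executing the kernel inplace would be unsound.
      assert(output_blobs[i] == input_blobs[i]);
      assert(output_blobs[i]->device_id == input_blobs[i]->device_id);
      assert(output_blobs[i]->shape == input_blobs[i]->shape);
    }
  }
  return is_inplace;
}
```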
18 changes: 18 additions & 0 deletions oneflow/core/framework/tensor.cpp
@@ -20,6 +20,9 @@ limitations under the License.
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"

namespace oneflow {

@@ -62,6 +65,21 @@ Maybe<MirroredTensor> MirroredTensor::api_detach() const {
return std::make_shared<MirroredTensor>(JUST(impl_->detach()));
}

Maybe<Tensor> MirroredTensor::clone() const {
const auto& device_type = JUST(this->device())->type();
int64_t device_id = JUST(this->device())->device_id();
std::shared_ptr<OpExpr> copy_op_ = JUST(one::OpBuilder("copy")
.Input("in", 1)
.Attr("device_type", device_type)
.Attr("device_id", device_id)
.Output("out", 1)
.Build());
std::shared_ptr<MirroredTensor> input =
std::const_pointer_cast<MirroredTensor>(shared_from_this());
const auto& output = JUST(OpInterpUtil::Dispatch<Tensor>(*copy_op_, {input}));
return output;
}

Maybe<ConsistentTensor> ConsistentTensor::MakeTensor(
const std::shared_ptr<const Shape>& shape, DataType dtype,
Symbol<cfg::ParallelDistribution> parallel_distribution, Symbol<ParallelDesc> parallel_desc,
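`MirroredTensor::clone()` above is implemented by building a `copy` user op targeting the tensor's own device and dispatching it through the normal interpreter, so the clone gets a proper grad function like any other op output. The sketch below only mimics that builder call shape with made-up classes; it is not OneFlow's real `OpBuilder`/`OpExpr` API.

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Made-up stand-ins for OpExpr / OpBuilder, only to show the call shape.
struct OpExprSketch {
  std::string op_type;
  std::map<std::string, std::string> attrs;
};

class OpBuilderSketch {
 public:
  explicit OpBuilderSketch(std::string op_type) { expr_.op_type = std::move(op_type); }
  OpBuilderSketch& Attr(const std::string& key, const std::string& value) {
    expr_.attrs[key] = value;
    return *this;
  }
  std::shared_ptr<OpExprSketch> Build() const { return std::make_shared<OpExprSketch>(expr_); }

 private:
  OpExprSketch expr_;
};

// clone()-style construction: a "copy" op targeting the source tensor's own
// device, which the interpreter then dispatches like any other user op.
std::shared_ptr<OpExprSketch> MakeCopyOp(const std::string& device_type, int64_t device_id) {
  return OpBuilderSketch("copy")
      .Attr("device_type", device_type)
      .Attr("device_id", std::to_string(device_id))
      .Build();
}
```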