Feat empty op (#5659)
* fix bugs in sharing EagerBlobObject::blob_desc_.shape and EagerBlobObject::blob_.shape

* feat(EmptyOp): add flow.empty

* docs(EmptyOp): add doctest and refine document

* docs(EmptyOp): refine document

* refactor(Tensor): Tensor constructor use empty_op

* refactor(Tensor): remove useless code

* feat(EmptyOp): support construction on a given device and add consistent_empty op

* feat(EmptyOp): support unpacked tuple shape

* refine array functor code

* docs(EmptyOp): update empty op document

* refine code

* docs(EmptyOp): add test and document for consistent empty op

* update document

* fix merge bugs

* fix(*): fix infer distribution

* test(EmptyOp): fix ConsistentEmptyOp CPU_ONLY test bug

* fix(*): init shape when InitBlob

* fix(*): Constant and Empty Op use broadcast sbp

* fix(indexing): replace MakeTensor with functional::Empty

* fix(*): fix compile bug

* refine code

* fix(nnGraph): make eager tensor

* auto format by CI

* fix(Stride): infer stride before initializing shape

Co-authored-by: Xinqi Li <lixinqi0703106@163.com>
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
5 people committed Aug 7, 2021
1 parent 578aa5c commit ed9b5a5
Showing 21 changed files with 437 additions and 96 deletions.
3 changes: 2 additions & 1 deletion docs/source/oneflow.rst
@@ -24,6 +24,7 @@ oneflow
load,
masked_fill,
matmul,
empty,
mish,
ones,
ones_like,
@@ -49,4 +50,4 @@ oneflow
zeros,
zeros_like

.. autofunction:: oneflow.data.load_mnist(train_batch_size=100, test_batch_size=100, data_format='NCHW')
.. autofunction:: oneflow.data.load_mnist(train_batch_size=100, test_batch_size=100, data_format='NCHW')
36 changes: 8 additions & 28 deletions oneflow/api/python/framework/tensor.cpp
@@ -51,23 +51,6 @@ const DType* GetTensorDType(const Tensor& tensor) {
return DType::Get(tensor.dtype()).GetOrThrow().get();
}

std::shared_ptr<Tensor> MakeLocalTensor(const std::shared_ptr<const Shape>& shape,
const DType* dtype, const Symbol<Device>& device,
bool is_lazy, bool requires_grad, bool is_leaf) {
return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad,
is_leaf)
.GetPtrOrThrow();
}

std::shared_ptr<Tensor> MakeConsistentTensor(
const std::shared_ptr<const Shape>& shape, const DType* dtype,
Symbol<cfg::ParallelDistribution>& parallel_distribution, Symbol<ParallelDesc> parallel_desc,
bool is_lazy, bool requires_grad, bool is_leaf) {
return ConsistentTensor::MakeTensor(shape, dtype->data_type(), parallel_distribution,
parallel_desc, is_lazy, requires_grad, is_leaf)
.GetPtrOrThrow();
}

Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
const auto& tensor = JUST(t->AsMirroredTensor());
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
@@ -180,21 +163,17 @@ Maybe<Tensor> MakeLocalTensorByNumpy(py::object array, const DType* desired_dtyp
auto* np_arr = reinterpret_cast<PyArrayObject*>(np_arr_pyobject);
bool init_from_numpy = py::isinstance<py::array>(array);
const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);
const auto shape = std::make_shared<Shape>(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));
const Shape shape = Shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));
DataType flow_dtype = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));
std::shared_ptr<Tensor> tensor =
MirroredTensor::MakeTensor(shape, flow_dtype, device, /* is_lazy */ false, requires_grad,
/* is_leaf */ true)
.GetPtrOrThrow();
std::shared_ptr<Tensor> tensor = JUST(functional::Empty(shape, flow_dtype, device));
JUST(SwitchCopyMirroredTensorFromUntypedArray(SwitchCase(flow_dtype), tensor, np_arr_raii));
if (flow_dtype == DataType::kDouble && !init_from_numpy && desired_dtype == nullptr) {
desired_dtype = DType::Float().get();
}
if (desired_dtype != nullptr) {
autograd::NoGradGuard no_grad;
tensor = JUST(functional::Cast(tensor, desired_dtype->data_type()));
tensor->set_requires_grad(requires_grad);
}
tensor->set_requires_grad(requires_grad);
return tensor;
}

@@ -318,11 +297,12 @@ Maybe<Tensor> NewTensor(py::args args, py::kwargs kwargs, const DType* desired_d
return Error::ValueError("invalid arg: " + py::str(arg).cast<std::string>());
}
}
const Shape shape = Shape(dim_vector);
CHECK_NOTNULL_OR_RETURN(desired_dtype);
std::shared_ptr<MirroredTensor> tensor = JUST(
MirroredTensor::MakeTensor(std::make_shared<Shape>(dim_vector), desired_dtype->data_type(),
device, /* is_lazy */ false, requires_grad, /* is_leaf */ true));
return std::static_pointer_cast<Tensor>(tensor);
std::shared_ptr<Tensor> tensor =
JUST(functional::Empty(shape, desired_dtype->data_type(), device));
tensor->set_requires_grad(requires_grad);
return tensor;
}

std::shared_ptr<Tensor> ApiNewTensor(py::args args, py::kwargs kwargs) {
7 changes: 3 additions & 4 deletions oneflow/api/python/functional/indexing.cpp
@@ -148,22 +148,21 @@ Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) {
const DataType dtype = JUST(InferScalarType(object));
const auto& sizes = JUST(InferArraySizes(object));
const auto& device = JUST(Device::New("cpu"));
const auto& tensor = JUST(MirroredTensor::MakeTensor(sizes, dtype, *device, /*is_lazy=*/false,
/*requires_grad=*/false, /*is_leaf=*/true));
const auto& tensor = JUST(functional::Empty(*sizes, dtype, device));
// Prevent the python object release until the callback is complete.
Py_INCREF(object);
auto handle = std::shared_ptr<PyObject>(PyObjectPtr(object));
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->AccessBlobByCallback(
tensor,
std::dynamic_pointer_cast<MirroredTensor>(tensor),
[handle](uint64_t of_blob_ptr) {
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
CHECK_JUST(ParseArrayToBlob(handle.get(), of_blob->mut_blob()));
},
"mut"));
return Maybe<void>::Ok();
}));
return std::dynamic_pointer_cast<Tensor>(tensor);
return tensor;
}

Maybe<IndexItem> UnpackIndexItem(PyObject* object) {
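The `Py_INCREF` plus `shared_ptr` pair above is a keep-alive idiom: the Python object must outlive the asynchronous blob callback that parses it. A generic sketch of the idiom (the lambda deleter stands in for OneFlow's `PyObjectPtr`, and it assumes the GIL is held whenever the deleter runs):

    #include <Python.h>
    #include <memory>

    // Bump the refcount, then let shared_ptr ownership drop it once the last
    // holder (for example, the async callback) goes away.
    std::shared_ptr<PyObject> KeepAlive(PyObject* obj) {
      Py_INCREF(obj);
      return std::shared_ptr<PyObject>(obj, [](PyObject* p) { Py_DECREF(p); });
    }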
6 changes: 6 additions & 0 deletions oneflow/core/common/shape.cpp
@@ -68,6 +68,12 @@ Shape& Shape::operator=(const Shape& shape) {
return *this;
}

Shape& Shape::assign(const DimVector& dim_vec) {
dim_vec_ = dim_vec;
UpdateElemCnt();
return *this;
}

Shape& Shape::CheckNumAxesIdenticalAndAssign(const ShapeView& shape_view) {
CHECK_EQ(NumAxes(), shape_view.NumAxes());
std::copy(shape_view.ptr(), shape_view.ptr() + shape_view.NumAxes(), dim_vec_.data());
1 change: 1 addition & 0 deletions oneflow/core/common/shape.h
@@ -41,6 +41,7 @@ class Shape final {
Shape(const std::initializer_list<int64_t>& dim_vec);
~Shape() = default;
Shape& operator=(const Shape& shape);
Shape& assign(const DimVector& dim_vec);
Shape& CheckNumAxesIdenticalAndAssign(const ShapeView& shape_view);
Shape& LeftOnesExtendedAssign(const ShapeView& shape_view);

2 changes: 1 addition & 1 deletion oneflow/core/common/shape_view.cpp
@@ -73,7 +73,7 @@ template<typename DimT>
void ShapeViewBase<DimT>::ToShape(Shape* shape) const {
DimVector dim_vec;
this->ToDimVector(&dim_vec);
*shape = Shape(std::move(dim_vec));
shape->assign(dim_vec);
}

template class ShapeViewBase<const int64_t>;
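Taken together with the `Shape::assign` hunk above, this change supports the first bullet of the commit message: the eager blob descriptor and the tensor meta now alias one `Shape` object, so a synced runtime shape must be written into that object in place instead of swapping in a replacement. A minimal sketch of the aliasing behavior (the two-owner setup is an illustration of the commit message, not code from this diff):

    #include <memory>
    #include "oneflow/core/common/shape.h"

    void SharedShapeSketch() {
      using oneflow::DimVector;
      using oneflow::Shape;
      // Two owners alias one Shape, as blob_desc_ and the tensor meta now do.
      std::shared_ptr<Shape> shape = std::make_shared<Shape>(DimVector{2, 3});
      std::shared_ptr<const Shape> alias = shape;
      // assign() mutates the object in place and recomputes elem_cnt, so the
      // alias observes {4, 5} as well; no temporary Shape is constructed.
      shape->assign(DimVector{4, 5});
    }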
11 changes: 8 additions & 3 deletions oneflow/core/eager/eager_blob_object.cpp
@@ -54,9 +54,14 @@ Maybe<void> EagerBlobObject::TryInitBlob() {
Maybe<void> EagerBlobObject::InitBlob() {
CHECK_NE_OR_RETURN(blob_desc_.data_type(), DataType::kInvalidDataType);
if (!blob_desc_.shape().is_initialized()) { blob_desc_.set_shape(Shape(DimVector{})); }
char* header_buffer =
reinterpret_cast<char*>(const_cast<int64_t*>(blob_desc_.shape().dim_vec().data()));
blob_.reset(new Blob(*mem_case_, &blob_desc_, header_buffer, nullptr));
{
header_buffer_.reset();
int64_t header_byte_size = blob_desc_.AlignedByteSizeOfBlobHeader();
const auto& FreeHeader = [header_byte_size](char* dptr) { std::free(dptr); };
char* ptr = reinterpret_cast<char*>(std::malloc(header_byte_size));
header_buffer_ = std::unique_ptr<char, std::function<void(char*)>>(ptr, FreeHeader);
}
blob_.reset(new Blob(*mem_case_, &blob_desc_, header_buffer_.get(), nullptr));
return Maybe<void>::Ok();
}

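Before this change the blob header pointed straight into `blob_desc_.shape().dim_vec().data()`, which is exactly the unwanted sharing between `blob_desc_.shape` and `blob_.shape` named in the first commit bullet. The fix gives the header its own heap allocation, owned through a `unique_ptr` with a type-erased deleter to match the new `header_buffer_` member below. A self-contained sketch of that ownership idiom:

    #include <cstdlib>
    #include <functional>
    #include <memory>

    using HeaderBuffer = std::unique_ptr<char, std::function<void(char*)>>;

    // Allocate with std::malloc and free with std::free, mirroring InitBlob().
    HeaderBuffer AllocHeader(std::size_t byte_size) {
      char* ptr = static_cast<char*>(std::malloc(byte_size));
      return HeaderBuffer(ptr, [](char* dptr) { std::free(dptr); });
    }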
2 changes: 2 additions & 0 deletions oneflow/core/eager/eager_blob_object.h
@@ -51,6 +51,7 @@ class EagerBlobObject final : public BlobObject {
~EagerBlobObject() override {
non_pod_initer_.reset();
tensor_buffer_.reset();
header_buffer_.reset();
blob_.reset();
}

@@ -78,6 +79,7 @@

private:
std::unique_ptr<Blob> blob_;
std::unique_ptr<char, std::function<void(char*)>> header_buffer_;
std::shared_ptr<TensorBuffer> tensor_buffer_;
std::size_t blob_body_bytes_;
std::unique_ptr<MemoryAllocator> non_pod_initer_;
@@ -25,6 +25,7 @@ limitations under the License.
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_name_scope.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/stride.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/eager/foreign_boxing_util.h"
#include "oneflow/core/memory/memory_case_util.h"
@@ -53,6 +54,30 @@ Maybe<EagerMirroredTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>&
return tensor->mut_eager_mirrored_tensor_impl();
}

class MutMirroredTensorMeta : public TensorMeta {
public:
MutMirroredTensorMeta() : TensorMeta(std::make_shared<const Shape>(), kInvalidDataType) {}
MutMirroredTensorMeta(const MutMirroredTensorMeta&) = default;
MutMirroredTensorMeta(MutMirroredTensorMeta&&) = default;
~MutMirroredTensorMeta() override = default;
};

std::vector<TensorMeta*>* ThreadLocalDefaultOutputMutTensorMetas(int64_t size) {
static thread_local std::vector<MutMirroredTensorMeta> struct_vec;
static thread_local std::vector<TensorMeta*> ptr_vec;
struct_vec.resize(size);
ptr_vec.resize(size);
if (size == 1) {
ptr_vec.at(0) = &struct_vec.at(0); // unfold loop
} else if (size == 2) {
ptr_vec.at(0) = &struct_vec.at(0); // unfold loop
ptr_vec.at(1) = &struct_vec.at(1); // unfold loop
} else {
for (int i = 0; i < size; ++i) { ptr_vec.at(i) = &struct_vec.at(i); }
}
return &ptr_vec;
}

} // namespace

Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
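`ThreadLocalDefaultOutputMutTensorMetas` keeps two thread-local vectors alive across interpreter calls, so the per-op scratch space for output tensor metas costs no allocation on the hot path once the vectors have grown; the size-1 and size-2 branches are hand-unrolled instances of the general loop. A generic sketch of the same reuse pattern (the template helper is hypothetical, not OneFlow API):

    #include <cstdint>
    #include <vector>

    // Thread-local scratch: value storage plus a parallel pointer vector,
    // resized per call and reused across calls on the same thread.
    template <typename T>
    std::vector<T*>* ThreadLocalScratchPtrs(int64_t size) {
      static thread_local std::vector<T> storage;
      static thread_local std::vector<T*> ptrs;
      storage.resize(size);
      ptrs.resize(size);
      for (int64_t i = 0; i < size; ++i) { ptrs[i] = &storage[i]; }
      return &ptrs;
    }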
@@ -70,12 +95,15 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
}
std::shared_ptr<EagerBlobObjectList> output_eager_blob_objects =
std::make_shared<EagerBlobObjectList>(outputs->size());
auto* output_tensor_metas = ThreadLocalDefaultOutputMutTensorMetas(outputs->size());
for (int i = 0; i < outputs->size(); i++) {
if (!outputs->at(i)) {
outputs->at(i) =
std::make_shared<MirroredTensor>(std::make_shared<EagerMirroredTensorImpl>());
}
if (JUST(outputs->at(i)->has_eager_blob_object())) {
const auto& tensor_impl = std::make_shared<EagerMirroredTensorImpl>();
outputs->at(i) = std::make_shared<MirroredTensor>(tensor_impl);
output_tensor_metas->at(i) = tensor_impl->mut_tensor_meta();
} else {
bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object());
CHECK_OR_RETURN(has_eager_blob_object);
output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object());
}
}
@@ -110,14 +138,22 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
return CHECK_JUST(TensorImpl4Tensor(inputs.at(i)))->mut_tensor_meta();
},
[&](int32_t i) -> TensorMeta* {
return CHECK_JUST(TensorImpl4Tensor(outputs->at(i)))->mut_tensor_meta();
// using thread_local TensorMeta pointer if inplace.
// using tensor_impl TensorMeta pointer if not inplace.
return output_tensor_metas->at(i);
}));

for (int i = 0; i < output_eager_blob_objects->size(); i++) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
if (!output_eager_blob_objects->at(i)) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
tensor_impl->mut_tensor_meta()->set_stride(std::make_shared<Stride>(*tensor_impl->shape()));
JUST(tensor_impl->InitEagerBlobObject(JUST(outputs->at(i)->device())->mem_case()));
output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
} else {
// output i is inplaced.
// check thread_local TensorMeta and tensor_impl TensorMeta.
CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape());
CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype());
}
}

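For a freshly created output, the stride is now derived from the inferred shape before `InitEagerBlobObject` runs (the "infer stride before initializing shape" bullet). A worked sketch of what `Stride(*shape)` plausibly computes, assuming row-major contiguous strides measured in elements (that assumption is mine; the real constructor lives in `oneflow/core/framework/stride.h`):

    #include <cstdint>
    #include <vector>

    // Row-major contiguous strides: stride[i] is the product of dims to the
    // right, so shape {2, 3, 4} yields strides {12, 4, 1}.
    std::vector<int64_t> ContiguousStride(const std::vector<int64_t>& shape) {
      std::vector<int64_t> stride(shape.size(), 1);
      for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
        stride[i] = stride[i + 1] * shape[i + 1];
      }
      return stride;
    }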
6 changes: 1 addition & 5 deletions oneflow/core/framework/tensor.cpp
@@ -48,11 +48,7 @@ Maybe<MirroredTensor> StaticZerosTensor::AsMirroredTensor() {
} else {
const auto& impl =
std::make_shared<EagerMirroredTensorImpl>(tensor_meta, requires_grad, is_leaf);
const auto& tensor = std::make_shared<MirroredTensor>(impl);
const auto& outputs = std::make_shared<TensorTuple>();
outputs->push_back(tensor);
JUST(RunEmptyOp(outputs.get()));
return tensor;
return std::make_shared<MirroredTensor>(impl);
}
}

12 changes: 10 additions & 2 deletions oneflow/core/framework/tensor_impl.cpp
@@ -29,6 +29,7 @@ limitations under the License.
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/register/ofblob.h"

namespace oneflow {
namespace one {
@@ -133,9 +134,16 @@ const std::shared_ptr<const Shape>& EagerMirroredTensorImpl::shape() const {

std::atomic<bool> synced(false);

const auto& shape_ptr = eager_blob_object_->blob_desc().shape_ptr();
CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->AccessBlobByCallback(
this, [&synced](uint64_t) { synced = true; }, "const"));
this,
[&synced, &shape_ptr](uint64_t of_blob_ptr) {
const auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
of_blob->blob().shape_view().ToShape(const_cast<Shape*>(shape_ptr.get()));
synced = true;
},
"const"));
return Maybe<void>::Ok();
}));

@@ -145,7 +153,7 @@ const std::shared_ptr<const Shape>& EagerMirroredTensorImpl::shape() const {
});

eager_blob_object_->set_is_shape_synced(true);
return eager_blob_object_->blob_desc().shape_ptr();
return shape_ptr;
}

Maybe<MirroredTensorImpl> EagerMirroredTensorImpl::detach() const {
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -285,6 +285,14 @@
signature: "Tensor ConsistentConstant(*, Shape shape, Scalar value, DataType dtype, Placement placement, SbpList sbp_tuple)"
bind_python: True

- name: "empty"
signature: "Tensor Empty(*, Shape shape, DataType dtype, Device device=None)"
bind_python: True

- name: "consistent_empty"
signature: "Tensor ConsistentEmpty(*, Shape shape, DataType dtype, Placement placement, SbpList sbp_tuple)"
bind_python: True

- name: "zeros_like"
signature: "Tensor ZerosLike(Tensor x)"
bind_python: True
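These signatures are the C++ functional layer that the hunks above call and that the Python bindings (`flow.empty` and the consistent variant) are generated from. A usage sketch of the local variant (the header path and the `Maybe`/`JUST` plumbing are assumptions consistent with the call sites in this commit):

    #include "oneflow/core/functional/functional.h"

    using namespace oneflow;
    using namespace oneflow::one;

    Maybe<Tensor> MakeScratchTensor(const Symbol<Device>& device) {
      // Uninitialized 2x3 float tensor on `device`; the contents are whatever
      // the allocator returns, which is the point of an empty op.
      std::shared_ptr<Tensor> t =
          JUST(functional::Empty(Shape({2, 3}), DataType::kFloat, device));
      return t;
    }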
