diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp index 97a83ae9f59..a9c191af02b 100644 --- a/oneflow/api/python/framework/tensor.cpp +++ b/oneflow/api/python/framework/tensor.cpp @@ -38,42 +38,31 @@ namespace one { namespace { -template -const DType* GetTensorDType(const T& tensor) { +const DType* GetTensorDType(const Tensor& tensor) { return DType::Get(tensor.dtype()).GetOrThrow().get(); } -template -struct TensorExportUtil final {}; - -template<> -struct TensorExportUtil final { - static std::shared_ptr MakeTensor(const std::shared_ptr& shape, - const DType* dtype, - const Symbol& device, bool is_lazy, - bool requires_grad, bool is_leaf) { - return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad, - is_leaf) - .GetPtrOrThrow(); - } -}; - -template<> -struct TensorExportUtil final { - static std::shared_ptr MakeTensor( - const std::shared_ptr& shape, const DType* dtype, - const std::shared_ptr& parallel_distribution, - const std::shared_ptr& parallel_desc, bool is_lazy, bool requires_grad, - bool is_leaf) { - return ConsistentTensor::MakeTensor(shape, dtype->data_type(), SymbolOf(*parallel_distribution), - SymbolOf(*parallel_desc), is_lazy, requires_grad, is_leaf) - .GetPtrOrThrow(); - } -}; +std::shared_ptr MakeLocalTensor(const std::shared_ptr& shape, + const DType* dtype, const Symbol& device, + bool is_lazy, bool requires_grad, bool is_leaf) { + return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad, + is_leaf) + .GetPtrOrThrow(); +} -namespace { +std::shared_ptr MakeConsistentTensor( + const std::shared_ptr& shape, const DType* dtype, + Symbol& parallel_distribution, Symbol parallel_desc, + bool is_lazy, bool requires_grad, bool is_leaf) { + return ConsistentTensor::MakeTensor(shape, dtype->data_type(), parallel_distribution, + parallel_desc, is_lazy, requires_grad, is_leaf) + .GetPtrOrThrow(); +} -Maybe EagerMirroredTensorZeros(const std::shared_ptr& tensor) { +Maybe EagerMirroredTensorZeros(const std::shared_ptr& t) { + const auto& tensor = std::dynamic_pointer_cast(t); + CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only"; + CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only"; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { JUST(builder->AccessBlobByCallback( tensor, @@ -84,19 +73,21 @@ Maybe EagerMirroredTensorZeros(const std::shared_ptr& tens "mut")); return Maybe::Ok(); })); - return Maybe::Ok(); } -void ApiEagerMirroredTensorZeros(const std::shared_ptr& tensor) { +void ApiEagerMirroredTensorZeros(const std::shared_ptr& tensor) { return EagerMirroredTensorZeros(tensor).GetOrThrow(); } template -Maybe CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr& tensor, +Maybe CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr& t, py::array_t array, void (*Copy)(uint64_t, py::array_t), const std::string& modifier) { + const auto& tensor = std::dynamic_pointer_cast(t); + CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only"; + CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only"; std::atomic synced(false); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { @@ -118,15 +109,13 @@ Maybe CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr -void ApiCopyMirroredTensorToNumpy(const std::shared_ptr& tensor, - py::array_t array) { +void ApiCopyMirroredTensorToNumpy(const std::shared_ptr& tensor, py::array_t array) { return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyToBuffer, "const") .GetOrThrow(); } template -void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr& tensor, - py::array_t array) { +void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr& tensor, py::array_t array) { return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyFromBuffer, "mut") .GetOrThrow(); } @@ -161,21 +150,22 @@ const std::string& ApiGetCopyMirroredTensorFromNumpyFuncName(const Tensor& tenso return *GetCopyMirroredTensorFromNumpyFuncName(tensor.dtype()).GetPtrOrThrow(); } -Symbol TensorGetDevice(const MirroredTensor& tensor) { - return tensor.device().GetOrThrow(); -} +Symbol TensorGetDevice(const Tensor& tensor) { return tensor.device().GetOrThrow(); } -std::shared_ptr TensorGetParallelDesc(const ConsistentTensor& tensor) { - return tensor.parallel_desc().GetOrThrow().shared_from_symbol(); +Symbol TensorGetParallelDesc(const Tensor& tensor) { + return tensor.parallel_desc().GetOrThrow(); } -std::tuple, std::vector> GetTensorBufferShapesAndDTypes( - const std::shared_ptr& tensor) { +Maybe, std::vector>> +MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr& t) { + const auto& tensor = std::dynamic_pointer_cast(t); + CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only"; + CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only"; std::vector shapes; std::vector dtypes; std::atomic synced(false); - CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { + JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { JUST(builder->AccessBlobByCallback( tensor, [&synced](uint64_t of_blob_ptr) { synced = true; }, "const")); return Maybe::Ok(); @@ -185,7 +175,7 @@ std::tuple, std::vector> GetTensorBufferShapesA while (!synced) {} }); - const Blob& blob = CHECK_JUST(tensor->eager_blob_object())->blob(); + const Blob& blob = JUST(tensor->eager_blob_object())->blob(); const Shape& blob_shape = blob.static_shape(); const auto* tensor_buffer_ptr = blob.dptr(); for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) { @@ -193,63 +183,56 @@ std::tuple, std::vector> GetTensorBufferShapesA shapes.push_back(tensor_buffer->shape()); dtypes.push_back(DType::Get(tensor_buffer->data_type()).GetOrThrow().get()); } - return std::make_tuple(shapes, dtypes); } -} // namespace - -void SpecializedDef(py::class_>* api) { - using T = MirroredTensor; - api->def_property_readonly("device", &TensorGetDevice); - api->def_property_readonly("data", &T::data); - api->def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes); -#define DEFINE_TENSOR_METHOD(T, type_proto) \ - api->def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy); \ - api->def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy); - OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ); - -#undef DEFINE_TENSOR_METHOD - api->def("_get_copy_mirrored_tensor_to_numpy_func_name", - &ApiGetCopyMirroredTensorToNumpyFuncName); - api->def("_get_copy_mirrored_tensor_from_numpy_func_name", - &ApiGetCopyMirroredTensorFromNumpyFuncName); - api->def("zeros_", &ApiEagerMirroredTensorZeros); - api->def("_register_hook", - [](const std::shared_ptr& self, const AutogradMeta::Hook& hook) -> void { - if (!self->grad_fn_node()) { CHECK_JUST(AddAccumulateFunctionNode(self)); } - self->mut_autograd_meta()->add_hook(hook); - }); +std::tuple, std::vector> GetTensorBufferShapesAndDTypes( + const std::shared_ptr& tensor) { + return MaybeGetTensorBufferShapesAndDTypes(tensor).GetOrThrow(); } -void SpecializedDef(py::class_>* api) { - api->def_property_readonly("placement", &TensorGetParallelDesc); +Maybe RegisterTensorHook(const std::shared_ptr& self, + const AutogradMeta::Hook& hook) { + if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); } + self->mut_autograd_meta()->add_hook(hook); + return Maybe::Ok(); +} +void ApiRegisterTensorHook(const std::shared_ptr& self, const AutogradMeta::Hook& hook) { + return RegisterTensorHook(self, hook).GetOrThrow(); } -template -void ExportTensor(py::module& m, const char* name) { - py::class_> tensor_api(m, name); - tensor_api - .def(py::init(&TensorExportUtil::MakeTensor)) +} // namespace + +ONEFLOW_API_PYBIND11_MODULE("", m) { + py::class_>(m, "Tensor") + .def(py::init(&MakeLocalTensor)) + .def(py::init(&MakeConsistentTensor)) // Properties of pytorch - .def_property_readonly("shape", &T::shape) - .def_property_readonly("dtype", &GetTensorDType) - .def_property_readonly("is_cuda", &T::is_cuda) - .def_property_readonly("grad", [](const T& t) { return t.api_acc_grad().GetPtrOrThrow(); }) + .def_property_readonly("shape", &Tensor::shape) + .def_property_readonly("dtype", &GetTensorDType) + .def_property_readonly("is_cuda", &Tensor::is_cuda) + .def_property_readonly("grad", + [](const Tensor& t) -> std::shared_ptr { + if (t.has_autograd_meta()) { + return t.acc_grad().GetPtrOrThrow(); + } else { + return std::shared_ptr(); + } + }) // setter of grad .def("set_grad", - [](T& t, const std::shared_ptr& grad) { + [](Tensor& t, const std::shared_ptr& grad) { if (t.is_leaf()) { - t.set_acc_grad(grad); + t.set_acc_grad(grad).GetOrThrow(); } else { throw std::runtime_error("You can only change gradient of leaf tensors."); } }) - .def_property_readonly("grad_fn", &T::grad_fn_node) - .def_property_readonly("is_leaf", &T::is_leaf) + .def_property_readonly("grad_fn", &Tensor::grad_fn_node) + .def_property_readonly("is_leaf", &Tensor::is_leaf) .def_property( - "requires_grad", &T::requires_grad, - [](T& t, bool requires_grad) { + "requires_grad", &Tensor::requires_grad, + [](Tensor& t, bool requires_grad) { if (t.is_leaf()) { t.set_requires_grad(requires_grad); } else { @@ -258,23 +241,32 @@ void ExportTensor(py::module& m, const char* name) { }) // Methods of pytorch .def("retain_grad", - [](T& t) { + [](Tensor& t) { if (!t.is_leaf()) { t.set_retain_grad(true).GetOrThrow(); } }) - .def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); }) - .def("clone", [](const T& t) { return t.api_clone().GetPtrOrThrow(); }) + .def("detach", [](const Tensor& t) { return t.detach().GetPtrOrThrow(); }) + .def("clone", [](const Tensor& t) { return t.clone().GetPtrOrThrow(); }) // OneFlow tensor properties other than pytorch tensor - .def_property_readonly("is_lazy", &T::is_lazy) - .def_property_readonly("is_consistent", &T::is_consistent); - SpecializedDef(&tensor_api); -} - -} // namespace - -ONEFLOW_API_PYBIND11_MODULE("", m) { - py::class_>(m, "Tensor"); - ExportTensor(m, "LocalTensor"); - ExportTensor(m, "ConsistentTensor"); + .def_property_readonly("is_lazy", &Tensor::is_lazy) + .def_property_readonly("is_eager", &Tensor::is_eager) + .def_property_readonly("is_consistent", &Tensor::is_consistent) + .def_property_readonly("is_local", &Tensor::is_local) + .def("zeros_", &ApiEagerMirroredTensorZeros) + .def("_register_hook", &ApiRegisterTensorHook) + // local tensor only + .def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes) + .def_property_readonly("device", &TensorGetDevice) + .def_property_readonly("data", &Tensor::data) +#define DEFINE_TENSOR_METHOD(T, type_proto) \ + .def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy) \ + .def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy) + OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ) +#undef DEFINE_TENSOR_METHOD + .def("_get_copy_mirrored_tensor_to_numpy_func_name", &ApiGetCopyMirroredTensorToNumpyFuncName) + .def("_get_copy_mirrored_tensor_from_numpy_func_name", + &ApiGetCopyMirroredTensorFromNumpyFuncName) + // consistent tensor only + .def_property_readonly("placement", &TensorGetParallelDesc); } } // namespace one diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 4fccd92de1c..920486f8313 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -67,13 +67,14 @@ int64_t MirroredTensor::dim(int64_t index) const { return shape()->At(index); } int64_t MirroredTensor::nelement() const { return shape()->elem_cnt(); } -std::shared_ptr MirroredTensor::data() const { +std::shared_ptr MirroredTensor::data() const { std::shared_ptr t = std::make_shared(impl_); return t; } -Maybe MirroredTensor::api_detach() const { - return std::make_shared(JUST(impl_->detach())); +Maybe MirroredTensor::detach() const { + std::shared_ptr tensor = std::make_shared(JUST(impl_->detach())); + return tensor; } Maybe MirroredTensor::clone() const { @@ -117,13 +118,13 @@ int64_t ConsistentTensor::nelement() const { return shape()->elem_cnt(); } int64_t ConsistentTensor::ndim() const { return shape()->NumAxes(); } -std::shared_ptr ConsistentTensor::data() const { +std::shared_ptr ConsistentTensor::data() const { std::shared_ptr t = std::make_shared(impl_); return t; } -Maybe ConsistentTensor::api_detach() const { - std::shared_ptr t = std::make_shared(impl_); +Maybe ConsistentTensor::detach() const { + std::shared_ptr t = std::make_shared(impl_); return t; } diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index fa3f744ee73..b81a32107cb 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -48,9 +48,13 @@ class Tensor { virtual Maybe> parallel_distribution() const = 0; virtual Maybe> parallel_desc() const = 0; virtual Maybe> device() const = 0; - virtual Maybe*> mut_device() { OF_UNIMPLEMENTED(); } + virtual Maybe*> mut_device() = 0; + virtual int64_t ndim() const = 0; + virtual bool is_cuda() const = 0; virtual bool is_consistent() const = 0; + virtual bool is_local() const { return !is_consistent(); } virtual bool is_lazy() const = 0; + virtual bool is_eager() const { return !is_lazy(); } virtual const TensorMeta& tensor_meta() const = 0; virtual Maybe> consistent_tensor_meta() const { OF_UNIMPLEMENTED(); } @@ -81,6 +85,7 @@ class Tensor { virtual Maybe now_grad_arg() const = 0; virtual Maybe detach() const = 0; virtual Maybe clone() const = 0; + virtual std::shared_ptr data() const = 0; // Setters for autograd virtual void set_requires_grad(bool requires_grad) = 0; @@ -106,8 +111,6 @@ class TensorIf : public Tensor { virtual ~TensorIf() = default; // Getters - virtual int64_t ndim() const = 0; - virtual bool is_cuda() const = 0; virtual int64_t nelement() const = 0; virtual int64_t dim(int64_t index) const = 0; @@ -115,15 +118,6 @@ class TensorIf : public Tensor { // acc_grad is tensor's accumulated grad in more than once backward operation, // and now_grad_arg is temporary grad to shared data with different FunctionNode std::shared_ptr grad_fn_node() const override { return grad_fn_node_; } - // used by pybind11 only - Maybe api_acc_grad() const { - if (has_autograd_meta()) { - const std::shared_ptr& tensor = JUST(acc_grad()); - return cast_for_api(tensor); - } else { - return std::shared_ptr(); - } - } // Setters for autograd void set_grad_fn_node(const std::shared_ptr& grad_fn_node) override { @@ -131,29 +125,9 @@ class TensorIf : public Tensor { } const std::shared_ptr& mut_grad_fn_node() override { return grad_fn_node_; } - Maybe detach() const override { - return std::static_pointer_cast(JUST(api_detach())); - } - - // Operators for tensor - // used by pybind11 only - virtual Maybe api_detach() const = 0; - Maybe api_clone() const { - const std::shared_ptr& tensor = JUST(clone()); - return cast_for_api(tensor); - } - protected: TensorIf() = default; std::shared_ptr grad_fn_node_; - - private: - Maybe cast_for_api(const std::shared_ptr& tensor) const { - if (!tensor) { return std::shared_ptr(); } - const auto& ptr = std::dynamic_pointer_cast(tensor); - CHECK_OR_RETURN(ptr) << Error::ValueError("Tensor Cast Error"); - return ptr; - } }; class MirroredTensor final : public TensorIf, @@ -179,7 +153,7 @@ class MirroredTensor final : public TensorIf, bool is_cuda() const override; int64_t dim(int64_t index) const override; int64_t nelement() const override; - std::shared_ptr data() const; + std::shared_ptr data() const override; const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); } // Getters valid only for EagerMirroredTensor @@ -216,7 +190,7 @@ class MirroredTensor final : public TensorIf, } // Operators for tensor - Maybe api_detach() const override; + Maybe detach() const override; Maybe clone() const override; static Maybe MakeTensor(const std::shared_ptr& shape, DataType dtype, @@ -251,6 +225,7 @@ class ConsistentTensor final : public TensorIf { } Maybe> parallel_desc() const override { return impl_->parallel_desc(); } Maybe> device() const override { OF_UNIMPLEMENTED(); } + Maybe*> mut_device() override { OF_UNIMPLEMENTED(); } bool is_lazy() const override { return impl_->is_lazy(); } bool is_consistent() const override { return true; } Maybe> consumer_parallel_distribution_constraint() @@ -264,7 +239,7 @@ class ConsistentTensor final : public TensorIf { bool is_cuda() const override; int64_t dim(int64_t index) const override; int64_t nelement() const override; - std::shared_ptr data() const; + std::shared_ptr data() const override; // Getters valid only for EagerMirroredTensor Maybe eager_blob_object() const override { @@ -308,7 +283,7 @@ class ConsistentTensor final : public TensorIf { } // Operators for tensor - virtual Maybe api_detach() const override; + Maybe detach() const override; Maybe clone() const override { return Error::Unimplemented(); } static Maybe MakeTensor(const std::shared_ptr& shape, diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index 604809852fb..12307e7c03f 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -39,7 +39,7 @@ def decorator(method): op_name = method.__name__ else: op_name = name - setattr(oneflow._oneflow_internal.LocalTensor, op_name, method) + setattr(oneflow._oneflow_internal.Tensor, op_name, method) return method return decorator @@ -868,7 +868,7 @@ def _default_initializer_for_determining(tensor): else: shape = undetermined_tensor.shape dtype = undetermined_tensor.dtype - determined_tensor = oneflow._oneflow_internal.LocalTensor( + determined_tensor = oneflow._oneflow_internal.Tensor( shape, dtype, undetermined_tensor.device, @@ -891,7 +891,7 @@ def _numpy_initializer_for_determining(tensor): if undetermined_tensor.is_consistent: raise NotImplementedError() else: - determined_tensor = oneflow._oneflow_internal.LocalTensor( + determined_tensor = oneflow._oneflow_internal.Tensor( undetermined_tensor.shape, undetermined_tensor.dtype, undetermined_tensor.device, @@ -913,13 +913,7 @@ def _input_args_is_numpy(*args): def _input_args_is_consistent_or_local(*args): - return len(args) == 1 and isinstance( - args[0], - ( - oneflow._oneflow_internal.ConsistentTensor, - oneflow._oneflow_internal.LocalTensor, - ), - ) + return len(args) == 1 and isinstance(args[0], oneflow._oneflow_internal.Tensor) def _input_args_is_tensor(*args): @@ -937,7 +931,7 @@ def _input_args_is_shape(*args): def register_tensor_op(op_name): def set_tensor_op(method): setattr(Tensor, op_name, method) - setattr(oneflow._oneflow_internal.LocalTensor, op_name, method) + setattr(oneflow._oneflow_internal.Tensor, op_name, method) return method return set_tensor_op diff --git a/oneflow/python/framework/tensor_tuple_util.py b/oneflow/python/framework/tensor_tuple_util.py index 75408c04c86..a1cee4d66eb 100644 --- a/oneflow/python/framework/tensor_tuple_util.py +++ b/oneflow/python/framework/tensor_tuple_util.py @@ -17,17 +17,17 @@ import collections from typing import Union, Sequence, Tuple, Optional -from oneflow.python.framework.tensor import Tensor -from oneflow._oneflow_internal import TensorTuple, LocalTensor +from oneflow.python.framework.tensor import Tensor as PyTensor +from oneflow._oneflow_internal import TensorTuple, Tensor def convert_to_tensor_tuple( - args: Optional[Union[Tensor, Sequence[Tensor], LocalTensor, Sequence[LocalTensor]]] + args: Optional[Union[PyTensor, Sequence[PyTensor], Tensor, Sequence[Tensor]]] ): if args is None: return TensorTuple() elif isinstance(args, collections.abc.Sequence): - if isinstance(args[0], Tensor): + if isinstance(args[0], PyTensor): for tensor in args: if not tensor.is_determined: tensor.determine() @@ -35,7 +35,7 @@ def convert_to_tensor_tuple( return TensorTuple(args) else: tensor_tuple = TensorTuple() - if isinstance(args, Tensor): + if isinstance(args, PyTensor): if not args.is_determined: args.determine() tensor_tuple.append(args._local_or_consistent_tensor) diff --git a/oneflow/python/nn/modules/eq.py b/oneflow/python/nn/modules/eq.py index 7faefb3b672..458a9fb913d 100644 --- a/oneflow/python/nn/modules/eq.py +++ b/oneflow/python/nn/modules/eq.py @@ -25,7 +25,7 @@ def __init__(self) -> None: def forward(self, input, other): if isinstance(other, flow.Tensor) or isinstance( - other, flow._oneflow_internal.LocalTensor + other, flow._oneflow_internal.Tensor ): for i in range(len(input.size())): assert ( diff --git a/oneflow/python/nn/modules/ne.py b/oneflow/python/nn/modules/ne.py index 08345281602..74ad5552a66 100644 --- a/oneflow/python/nn/modules/ne.py +++ b/oneflow/python/nn/modules/ne.py @@ -25,7 +25,7 @@ def __init__(self) -> None: def forward(self, input, other): if isinstance(other, flow.Tensor) or isinstance( - other, flow._oneflow_internal.LocalTensor + other, flow._oneflow_internal.Tensor ): for i in range(len(input.size())): assert (