export Tensor only to python #5440

Merged
merged 24 commits into from
Jul 11, 2021
Changes from all commits
Commits (24)
4316d77
export Tensor only to python
lixinqi Jul 9, 2021
ba963eb
Merge branch 'master' into refactor_tensor_api
lixinqi Jul 9, 2021
25159e8
Merge branch 'master' into refactor_tensor_api
lixinqi Jul 9, 2021
55b317b
Merge branch 'refactor_tensor_api' of github.com:Oneflow-Inc/oneflow …
lixinqi Jul 9, 2021
d24c582
address review comments
lixinqi Jul 9, 2021
7ed1d3d
Merge branch 'master' into refactor_tensor_api
lixinqi Jul 9, 2021
7f6d806
address review comments
lixinqi Jul 9, 2021
2d76231
Merge branch 'refactor_tensor_api' of github.com:Oneflow-Inc/oneflow …
lixinqi Jul 9, 2021
4f84945
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 9, 2021
0a82927
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 9, 2021
76504e0
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Jul 10, 2021
0aeef2b
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 10, 2021
2de39a0
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 10, 2021
6515c69
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 10, 2021
612d168
refine
clackhan Jul 10, 2021
0a15cfc
Merge branch 'refactor_tensor_api' of https://github.com/Oneflow-Inc/…
clackhan Jul 10, 2021
647a583
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 10, 2021
0d2aaa6
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 10, 2021
8b05a19
Update tensor_tuple_util.py
clackhan Jul 10, 2021
72a308b
Update tensor.cpp
clackhan Jul 10, 2021
a804094
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 10, 2021
0da9e81
Merge branch 'master' into refactor_tensor_api
oneflow-ci-bot Jul 10, 2021
085f692
Update tensor.cpp
clackhan Jul 11, 2021
4646b20
auto format by CI
oneflow-ci-bot Jul 11, 2021
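
Before the diff itself, here is a minimal, self-contained pybind11 sketch of the pattern this PR moves to: only a single Tensor class is exported to Python, and subtype-specific operations are gated by a runtime dynamic cast instead of being exposed on separate LocalTensor/ConsistentTensor Python classes. All names in the sketch (BaseTensor, LocalImpl, ConsistentImpl, Zeros_) are simplified stand-ins rather than the actual OneFlow types; the real code additionally routes errors through Maybe<> and runs instructions through a builder.

#include <pybind11/pybind11.h>

#include <algorithm>
#include <memory>
#include <stdexcept>
#include <vector>

namespace py = pybind11;

// Single exported base type; the concrete subtype is an implementation detail.
struct BaseTensor {
  virtual ~BaseTensor() = default;
  virtual bool is_local() const = 0;
};

struct LocalImpl final : public BaseTensor {
  std::vector<float> data;
  bool is_local() const override { return true; }
};

struct ConsistentImpl final : public BaseTensor {
  bool is_local() const override { return false; }
};

// Local-only operation: check the dynamic type and fail loudly otherwise,
// mirroring the CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only"
// checks added in this PR.
void Zeros_(const std::shared_ptr<BaseTensor>& t) {
  const auto local = std::dynamic_pointer_cast<LocalImpl>(t);
  if (!local) { throw std::runtime_error("local tensors supported only"); }
  std::fill(local->data.begin(), local->data.end(), 0.f);
}

PYBIND11_MODULE(tensor_sketch, m) {
  // One Python class ("Tensor") regardless of which C++ subtype sits behind
  // the shared_ptr; kind-specific methods are still defined on it.
  py::class_<BaseTensor, std::shared_ptr<BaseTensor>>(m, "Tensor")
      .def(py::init([]() -> std::shared_ptr<BaseTensor> {
        return std::make_shared<LocalImpl>();
      }))
      .def_property_readonly("is_local", &BaseTensor::is_local)
      .def("zeros_", &Zeros_);
}

Compared with exporting each subclass separately (the old ExportTensor<MirroredTensor> / ExportTensor<ConsistentTensor> path), this keeps the Python surface to one type, at the cost of moving the local/consistent distinction into runtime checks, which is exactly the shape of the bindings in the diff below.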
196 changes: 94 additions & 102 deletions oneflow/api/python/framework/tensor.cpp
@@ -38,42 +38,31 @@ namespace one {

namespace {

template<typename T>
const DType* GetTensorDType(const T& tensor) {
const DType* GetTensorDType(const Tensor& tensor) {
return DType::Get(tensor.dtype()).GetOrThrow().get();
}

template<typename T>
struct TensorExportUtil final {};

template<>
struct TensorExportUtil<MirroredTensor> final {
static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
const DType* dtype,
const Symbol<Device>& device, bool is_lazy,
bool requires_grad, bool is_leaf) {
return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad,
is_leaf)
.GetPtrOrThrow();
}
};

template<>
struct TensorExportUtil<ConsistentTensor> final {
static std::shared_ptr<ConsistentTensor> MakeTensor(
const std::shared_ptr<const Shape>& shape, const DType* dtype,
const std::shared_ptr<const cfg::ParallelDistribution>& parallel_distribution,
const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad,
bool is_leaf) {
return ConsistentTensor::MakeTensor(shape, dtype->data_type(), SymbolOf(*parallel_distribution),
SymbolOf(*parallel_desc), is_lazy, requires_grad, is_leaf)
.GetPtrOrThrow();
}
};
std::shared_ptr<Tensor> MakeLocalTensor(const std::shared_ptr<const Shape>& shape,
const DType* dtype, const Symbol<Device>& device,
bool is_lazy, bool requires_grad, bool is_leaf) {
return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad,
is_leaf)
.GetPtrOrThrow();
}

namespace {
std::shared_ptr<Tensor> MakeConsistentTensor(
const std::shared_ptr<const Shape>& shape, const DType* dtype,
Symbol<cfg::ParallelDistribution>& parallel_distribution, Symbol<ParallelDesc> parallel_desc,
bool is_lazy, bool requires_grad, bool is_leaf) {
return ConsistentTensor::MakeTensor(shape, dtype->data_type(), parallel_distribution,
parallel_desc, is_lazy, requires_grad, is_leaf)
.GetPtrOrThrow();
}

Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) {
Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->AccessBlobByCallback(
tensor,
@@ -84,19 +73,21 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tens
"mut"));
return Maybe<void>::Ok();
}));

return Maybe<void>::Ok();
}

void ApiEagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) {
void ApiEagerMirroredTensorZeros(const std::shared_ptr<Tensor>& tensor) {
return EagerMirroredTensorZeros(tensor).GetOrThrow();
}

template<typename T>
Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTensor>& tensor,
Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<Tensor>& t,
py::array_t<T> array,
void (*Copy)(uint64_t, py::array_t<T>),
const std::string& modifier) {
const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
std::atomic<bool> synced(false);

JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
@@ -118,15 +109,13 @@ Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTens
}

template<typename T>
void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<MirroredTensor>& tensor,
py::array_t<T> array) {
void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) {
return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyToBuffer, "const")
.GetOrThrow();
}

template<typename T>
void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<MirroredTensor>& tensor,
py::array_t<T> array) {
void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) {
return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyFromBuffer, "mut")
.GetOrThrow();
}
@@ -161,21 +150,22 @@ const std::string& ApiGetCopyMirroredTensorFromNumpyFuncName(const Tensor& tenso
return *GetCopyMirroredTensorFromNumpyFuncName(tensor.dtype()).GetPtrOrThrow();
}

Symbol<Device> TensorGetDevice(const MirroredTensor& tensor) {
return tensor.device().GetOrThrow();
}
Symbol<Device> TensorGetDevice(const Tensor& tensor) { return tensor.device().GetOrThrow(); }

std::shared_ptr<const ParallelDesc> TensorGetParallelDesc(const ConsistentTensor& tensor) {
return tensor.parallel_desc().GetOrThrow().shared_from_symbol();
Symbol<ParallelDesc> TensorGetParallelDesc(const Tensor& tensor) {
return tensor.parallel_desc().GetOrThrow();
}

std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes(
const std::shared_ptr<MirroredTensor>& tensor) {
Maybe<std::tuple<std::vector<Shape>, std::vector<const DType*>>>
MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
std::vector<Shape> shapes;
std::vector<const DType*> dtypes;
std::atomic<bool> synced(false);

CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->AccessBlobByCallback(
tensor, [&synced](uint64_t of_blob_ptr) { synced = true; }, "const"));
return Maybe<void>::Ok();
@@ -185,71 +175,64 @@ std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesA
while (!synced) {}
});

const Blob& blob = CHECK_JUST(tensor->eager_blob_object())->blob();
const Blob& blob = JUST(tensor->eager_blob_object())->blob();
const Shape& blob_shape = blob.static_shape();
const auto* tensor_buffer_ptr = blob.dptr<TensorBuffer>();
for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) {
const TensorBuffer* tensor_buffer = tensor_buffer_ptr + i;
shapes.push_back(tensor_buffer->shape());
dtypes.push_back(DType::Get(tensor_buffer->data_type()).GetOrThrow().get());
}

return std::make_tuple(shapes, dtypes);
}

} // namespace

void SpecializedDef(py::class_<MirroredTensor, Tensor, std::shared_ptr<MirroredTensor>>* api) {
using T = MirroredTensor;
api->def_property_readonly("device", &TensorGetDevice);
api->def_property_readonly("data", &T::data);
api->def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes);
#define DEFINE_TENSOR_METHOD(T, type_proto) \
api->def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>); \
api->def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>);
OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ);

#undef DEFINE_TENSOR_METHOD
api->def("_get_copy_mirrored_tensor_to_numpy_func_name",
&ApiGetCopyMirroredTensorToNumpyFuncName);
api->def("_get_copy_mirrored_tensor_from_numpy_func_name",
&ApiGetCopyMirroredTensorFromNumpyFuncName);
api->def("zeros_", &ApiEagerMirroredTensorZeros);
api->def("_register_hook",
[](const std::shared_ptr<MirroredTensor>& self, const AutogradMeta::Hook& hook) -> void {
if (!self->grad_fn_node()) { CHECK_JUST(AddAccumulateFunctionNode(self)); }
self->mut_autograd_meta()->add_hook(hook);
});
std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes(
const std::shared_ptr<Tensor>& tensor) {
return MaybeGetTensorBufferShapesAndDTypes(tensor).GetOrThrow();
}

void SpecializedDef(py::class_<ConsistentTensor, Tensor, std::shared_ptr<ConsistentTensor>>* api) {
api->def_property_readonly("placement", &TensorGetParallelDesc);
Maybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self,
const AutogradMeta::Hook& hook) {
if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }
self->mut_autograd_meta()->add_hook(hook);
return Maybe<void>::Ok();
}
void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMeta::Hook& hook) {
return RegisterTensorHook(self, hook).GetOrThrow();
}

template<typename T>
void ExportTensor(py::module& m, const char* name) {
py::class_<T, Tensor, std::shared_ptr<T>> tensor_api(m, name);
tensor_api
.def(py::init(&TensorExportUtil<T>::MakeTensor))
} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor")
.def(py::init(&MakeLocalTensor))
.def(py::init(&MakeConsistentTensor))
// Properties of pytorch
.def_property_readonly("shape", &T::shape)
.def_property_readonly("dtype", &GetTensorDType<T>)
.def_property_readonly("is_cuda", &T::is_cuda)
.def_property_readonly("grad", [](const T& t) { return t.api_acc_grad().GetPtrOrThrow(); })
.def_property_readonly("shape", &Tensor::shape)
.def_property_readonly("dtype", &GetTensorDType)
.def_property_readonly("is_cuda", &Tensor::is_cuda)
.def_property_readonly("grad",
[](const Tensor& t) -> std::shared_ptr<Tensor> {
if (t.has_autograd_meta()) {
return t.acc_grad().GetPtrOrThrow();
} else {
return std::shared_ptr<Tensor>();
}
})
// setter of grad
.def("set_grad",
[](T& t, const std::shared_ptr<T>& grad) {
[](Tensor& t, const std::shared_ptr<Tensor>& grad) {
if (t.is_leaf()) {
t.set_acc_grad(grad);
t.set_acc_grad(grad).GetOrThrow();
} else {
throw std::runtime_error("You can only change gradient of leaf tensors.");
}
})
.def_property_readonly("grad_fn", &T::grad_fn_node)
.def_property_readonly("is_leaf", &T::is_leaf)
.def_property_readonly("grad_fn", &Tensor::grad_fn_node)
.def_property_readonly("is_leaf", &Tensor::is_leaf)
.def_property(
"requires_grad", &T::requires_grad,
[](T& t, bool requires_grad) {
"requires_grad", &Tensor::requires_grad,
[](Tensor& t, bool requires_grad) {
if (t.is_leaf()) {
t.set_requires_grad(requires_grad);
} else {
@@ -258,23 +241,32 @@ void ExportTensor(py::module& m, const char* name) {
})
// Methods of pytorch
.def("retain_grad",
[](T& t) {
[](Tensor& t) {
if (!t.is_leaf()) { t.set_retain_grad(true).GetOrThrow(); }
})
.def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); })
.def("clone", [](const T& t) { return t.api_clone().GetPtrOrThrow(); })
.def("detach", [](const Tensor& t) { return t.detach().GetPtrOrThrow(); })
.def("clone", [](const Tensor& t) { return t.clone().GetPtrOrThrow(); })
// OneFlow tensor properties other than pytorch tensor
.def_property_readonly("is_lazy", &T::is_lazy)
.def_property_readonly("is_consistent", &T::is_consistent);
SpecializedDef(&tensor_api);
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor");
ExportTensor<MirroredTensor>(m, "LocalTensor");
ExportTensor<ConsistentTensor>(m, "ConsistentTensor");
.def_property_readonly("is_lazy", &Tensor::is_lazy)
.def_property_readonly("is_eager", &Tensor::is_eager)
.def_property_readonly("is_consistent", &Tensor::is_consistent)
.def_property_readonly("is_local", &Tensor::is_local)
.def("zeros_", &ApiEagerMirroredTensorZeros)
.def("_register_hook", &ApiRegisterTensorHook)
// local tensor only
.def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
.def_property_readonly("device", &TensorGetDevice)
.def_property_readonly("data", &Tensor::data)
#define DEFINE_TENSOR_METHOD(T, type_proto) \
.def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>) \
.def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>)
OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ)
Review comment (Contributor):
Is this indentation intentional? It looks a bit odd.

Reply (Contributor Author):
It is intentional; re-formatting it automatically doesn't come out well.

Reply (Contributor Author):
I'll switch it back to auto formatting.
#undef DEFINE_TENSOR_METHOD
.def("_get_copy_mirrored_tensor_to_numpy_func_name", &ApiGetCopyMirroredTensorToNumpyFuncName)
.def("_get_copy_mirrored_tensor_from_numpy_func_name",
&ApiGetCopyMirroredTensorFromNumpyFuncName)
// consistent tensor only
.def_property_readonly("placement", &TensorGetParallelDesc);
}

} // namespace one
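
Throughout the file above, the bound functions come in pairs: an internal function that reports failure through Maybe<> (for example RegisterTensorHook) and a thin Api* wrapper that calls .GetOrThrow() right at the pybind11 boundary (for example ApiRegisterTensorHook). The following compilable sketch shows that split with a deliberately simplified stand-in for Maybe<>; the real type carries richer error state, and the error condition here is invented purely for illustration.

#include <optional>
#include <stdexcept>
#include <string>
#include <utility>

// Deliberately tiny stand-in for OneFlow's Maybe<void>: success, or an error
// message that GetOrThrow() turns into a C++ exception.
struct Result {
  std::optional<std::string> error;  // empty == success
  static Result Ok() { return {}; }
  static Result Err(std::string msg) { return {std::move(msg)}; }
  void GetOrThrow() const {
    if (error) { throw std::runtime_error(*error); }
  }
};

// Internal layer: no exceptions, just a Result describing what went wrong
// (compare RegisterTensorHook above, which returns Maybe<void>).
Result RegisterHook(bool has_grad_fn_node) {
  if (!has_grad_fn_node) { return Result::Err("tensor has no autograd node"); }
  // ... attach the hook ...
  return Result::Ok();
}

// Python-facing layer: convert the Result into an exception so pybind11 can
// surface it as a Python error (compare ApiRegisterTensorHook above).
void ApiRegisterHook(bool has_grad_fn_node) {
  RegisterHook(has_grad_fn_node).GetOrThrow();
}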
13 changes: 7 additions & 6 deletions oneflow/core/framework/tensor.cpp
@@ -67,13 +67,14 @@ int64_t MirroredTensor::dim(int64_t index) const { return shape()->At(index); }

int64_t MirroredTensor::nelement() const { return shape()->elem_cnt(); }

std::shared_ptr<MirroredTensor> MirroredTensor::data() const {
std::shared_ptr<Tensor> MirroredTensor::data() const {
std::shared_ptr<MirroredTensor> t = std::make_shared<MirroredTensor>(impl_);
return t;
}

Maybe<MirroredTensor> MirroredTensor::api_detach() const {
return std::make_shared<MirroredTensor>(JUST(impl_->detach()));
Maybe<Tensor> MirroredTensor::detach() const {
std::shared_ptr<Tensor> tensor = std::make_shared<MirroredTensor>(JUST(impl_->detach()));
return tensor;
}

Maybe<Tensor> MirroredTensor::clone() const {
@@ -117,13 +118,13 @@ int64_t ConsistentTensor::nelement() const { return shape()->elem_cnt(); }

int64_t ConsistentTensor::ndim() const { return shape()->NumAxes(); }

std::shared_ptr<ConsistentTensor> ConsistentTensor::data() const {
std::shared_ptr<Tensor> ConsistentTensor::data() const {
std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_);
return t;
}

Maybe<ConsistentTensor> ConsistentTensor::api_detach() const {
std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_);
Maybe<Tensor> ConsistentTensor::detach() const {
std::shared_ptr<Tensor> t = std::make_shared<ConsistentTensor>(impl_);
return t;
}

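
The return-type changes in this second file (data() and detach() now returning std::shared_ptr<Tensor> / Maybe<Tensor> instead of the concrete subclass) line up with the single-class export above: once the methods are bound once via &Tensor::data and friends, every override has to share the base-class signature, and std::shared_ptr does not allow covariant overrides anyway. A small compilable sketch of that constraint, with simplified stand-in names (TensorBase, LocalTensorImpl, Detach) rather than the real OneFlow classes:

#include <cassert>
#include <memory>

struct TensorBase {
  virtual ~TensorBase() = default;
  // Declared on the base so one pybind11 binding, e.g.
  // def("detach", &TensorBase::Detach), can cover every subclass.
  virtual std::shared_ptr<TensorBase> Detach() const = 0;
};

struct LocalTensorImpl final : public TensorBase {
  // Must keep the base return type: an override returning
  // std::shared_ptr<LocalTensorImpl> would not compile, because smart
  // pointers are not covariant the way raw pointers are.
  std::shared_ptr<TensorBase> Detach() const override {
    return std::make_shared<LocalTensorImpl>(*this);
  }
};

int main() {
  const std::shared_ptr<TensorBase> t = std::make_shared<LocalTensorImpl>();
  const std::shared_ptr<TensorBase> detached = t->Detach();
  assert(detached != nullptr);
  return 0;
}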