Adds data_ptr method to Tensor and TensorList (#1773)
- adds the ability to expose a raw data pointer from Tensor and TensorList on CPU and GPU
- adds tests for `data_ptr` that create numpy or cupy arrays from the exposed pointers

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
JanuszL committed Mar 2, 2020
1 parent ca78df2 commit 03373f9
Showing 9 changed files with 211 additions and 52 deletions.
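
In practice, `data_ptr` returns the address of the underlying buffer as a plain Python integer, which zero-copy wrappers can consume. A minimal usage sketch, based on the tests added in this commit (shape and layout are illustrative):

    import numpy as np
    from nvidia.dali.backend_impl import TensorCPU

    arr = np.random.rand(3, 5, 6)
    tensor = TensorCPU(arr, "NHWC")

    # data_ptr() yields an integer address; pairing it with shape() and
    # dtype() is enough to rebuild a zero-copy view (see test_utils.py below).
    address = tensor.data_ptr()
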
99 changes: 61 additions & 38 deletions dali/python/backend_impl.cc
@@ -179,6 +179,12 @@ void ExposeTensor(py::module &m) {
},
R"code(
String representing NumPy type of the Tensor.
)code")
.def("data_ptr", [](Tensor<CPUBackend> &t) {
return py::reinterpret_borrow<py::object>(PyLong_FromVoidPtr(t.raw_mutable_data()));
},
R"code(
Returns the address of the first element of tensor.
)code");

py::class_<Tensor<GPUBackend>>(m, "TensorGPU")
@@ -220,6 +226,13 @@ void ExposeTensor(py::module &m) {
},
R"code(
String representing NumPy type of the Tensor.
)code")
.def("data_ptr",
[](Tensor<GPUBackend> &t) {
return py::reinterpret_borrow<py::object>(PyLong_FromVoidPtr(t.raw_mutable_data()));
},
R"code(
Returns the address of the first element of tensor.
)code");
}

@@ -254,11 +267,6 @@ py::tuple TensorListGetItemSliceImpl(TensorList<Backend> &t, py::slice slice) {
#endif

void ExposeTensorList(py::module &m) {
-  // We only want to wrap buffers w/ TensorLists to feed then to
-  // the backend. We do not support converting from TensorLists
-  // to numpy arrays currently.
-
-
py::class_<TensorList<CPUBackend>>(m, "TensorListCPU", py::buffer_protocol())
.def(py::init([](py::buffer b, string layout = "") {
// We need to verify that the input data is C_CONTIGUOUS
@@ -302,26 +310,27 @@ void ExposeTensorList(py::module &m) {
layout : the layout description
)code")
.def("layout", &TensorList<CPUBackend>::GetLayout)
.def("at", [](TensorList<CPUBackend> &t, Index id) -> py::array {
DALI_ENFORCE(IsValidType(t.type()), "Cannot produce "
.def("at", [](TensorList<CPUBackend> &tl, Index id) -> py::array {
DALI_ENFORCE(IsValidType(tl.type()), "Cannot produce "
"buffer info for tensor w/ invalid type.");
-          DALI_ENFORCE(static_cast<size_t>(id) < t.ntensor(), "Index is out-of-range.");
+          DALI_ENFORCE(static_cast<size_t>(id) < tl.ntensor(), "Index is out-of-range.");
DALI_ENFORCE(id >= 0, "Index is out-of-range.");

-          std::vector<ssize_t> shape(t.tensor_shape(id).size()), stride(t.tensor_shape(id).size());
+          std::vector<ssize_t> shape(tl.tensor_shape(id).size()),
+              stride(tl.tensor_shape(id).size());
size_t dim_prod = 1;
for (size_t i = 0; i < shape.size(); ++i) {
-            shape[i] = t.tensor_shape(id)[i];
+            shape[i] = tl.tensor_shape(id)[i];

// We iterate over stride backwards
-            stride[(stride.size()-1) - i] = t.type().size()*dim_prod;
-            dim_prod *= t.tensor_shape(id)[(shape.size()-1) - i];
+            stride[(stride.size()-1) - i] = tl.type().size()*dim_prod;
+            dim_prod *= tl.tensor_shape(id)[(shape.size()-1) - i];
}

return py::array(py::buffer_info(
-              t.raw_mutable_tensor(id),
-              t.type().size(),
-              FormatStrFromType(t.type()),
+              tl.raw_mutable_tensor(id),
+              tl.type().size(),
+              FormatStrFromType(tl.type()),
shape.size(), shape, stride));
},
R"code(
@@ -331,8 +340,8 @@ void ExposeTensorList(py::module &m) {
----------
)code")
.def("__getitem__",
-        [](TensorList<CPUBackend> &t, Index i) -> std::unique_ptr<Tensor<CPUBackend>> {
-          return TensorListGetItemImpl(t, i);
+        [](TensorList<CPUBackend> &tl, Index i) -> std::unique_ptr<Tensor<CPUBackend>> {
+          return TensorListGetItemImpl(tl, i);
},
"i"_a,
R"code(
@@ -344,8 +353,8 @@ void ExposeTensorList(py::module &m) {
py::keep_alive<0, 1>())
#if 0 // TODO(spanev): figure out which return_value_policy to choose
.def("__getitem__",
-        [](TensorList<CPUBackend> &t, py::slice slice) -> py::tuple {
-          return TensorListGetItemSliceImpl(t, slice);
+        [](TensorList<CPUBackend> &tl, py::slice slice) -> py::tuple {
+          return TensorListGetItemSliceImpl(tl, slice);
},
R"code(
Returns a tensor at given position in the list.
@@ -354,45 +363,45 @@ void ExposeTensorList(py::module &m) {
----------
)code")
#endif
.def("as_array", [](TensorList<CPUBackend> &t) -> py::array {
.def("as_array", [](TensorList<CPUBackend> &tl) -> py::array {
void* raw_mutable_data = nullptr;
std::string format;
size_t type_size;

-          if (t.size() > 0) {
-            DALI_ENFORCE(IsValidType(t.type()), "Cannot produce "
+          if (tl.size() > 0) {
+            DALI_ENFORCE(IsValidType(tl.type()), "Cannot produce "
              "buffer info for tensor w/ invalid type.");
-            DALI_ENFORCE(t.IsDenseTensor(),
+            DALI_ENFORCE(tl.IsDenseTensor(),
                         "Tensors in the list must have the same shape");
-            raw_mutable_data = t.raw_mutable_data();
+            raw_mutable_data = tl.raw_mutable_data();
}

-          if (IsValidType(t.type())) {
-            format = FormatStrFromType(t.type());
-            type_size = t.type().size();
+          if (IsValidType(tl.type())) {
+            format = FormatStrFromType(tl.type());
+            type_size = tl.type().size();
} else {
// Default is float
format = py::format_descriptor<float>::format();
type_size = sizeof(float);
}

-          auto shape_size = t.shape().size() > 0 ? t.tensor_shape(0).size() : 0;
+          auto shape_size = tl.shape().size() > 0 ? tl.tensor_shape(0).size() : 0;
std::vector<ssize_t> shape(shape_size + 1);
std::vector<ssize_t> strides(shape_size + 1);
size_t dim_prod = 1;
for (size_t i = 0; i < shape.size(); ++i) {
if (i == 0) {
-              shape[i] = t.shape().size();
+              shape[i] = tl.shape().size();
            } else {
-              shape[i] = t.tensor_shape(0)[i - 1];
+              shape[i] = tl.tensor_shape(0)[i - 1];
}

// We iterate over stride backwards
strides[(strides.size()-1) - i] = type_size*dim_prod;
if (i == shape.size() - 1) {
-              dim_prod *= t.shape().size();
+              dim_prod *= tl.shape().size();
            } else {
-              dim_prod *= t.tensor_shape(0)[(shape.size()-2) - i];
+              dim_prod *= tl.tensor_shape(0)[(shape.size()-2) - i];
}
}

@@ -404,8 +413,8 @@ void ExposeTensorList(py::module &m) {
Parameters
----------
)code")
.def("__len__", [](TensorList<CPUBackend> &t) {
return t.ntensor();
.def("__len__", [](TensorList<CPUBackend> &tl) {
return tl.ntensor();
})
.def("is_dense_tensor", &TensorList<CPUBackend>::IsDenseTensor,
R"code(
@@ -417,8 +426,8 @@ void ExposeTensorList(py::module &m) {
may be viewed as a tensor of shape `(N, H, W, C)`.
)code")
.def("copy_to_external",
-        [](TensorList<CPUBackend> &t, py::object p) {
-          CopyToExternalTensor(&t, ctypes_void_ptr(p), CPU, 0);
+        [](TensorList<CPUBackend> &tl, py::object p) {
+          CopyToExternalTensor(&tl, ctypes_void_ptr(p), CPU, 0);
},
R"code(
Copy the contents of this `TensorList` to an external pointer
@@ -447,7 +456,14 @@ void ExposeTensorList(py::module &m) {
This function can only be called if `is_dense_tensor` returns `True`.
)code",
-      py::return_value_policy::reference_internal);
+      py::return_value_policy::reference_internal)
+    .def("data_ptr",
+        [](TensorList<CPUBackend> &tl) {
+          return py::reinterpret_borrow<py::object>(PyLong_FromVoidPtr(tl.raw_mutable_data()));
+        },
+      R"code(
+      Returns the address of the first element of TensorList.
+      )code");

py::class_<TensorList<GPUBackend>>(m, "TensorListGPU", py::buffer_protocol())
.def(py::init([]() {
@@ -567,7 +583,14 @@ void ExposeTensorList(py::module &m) {
This function can only be called if `is_dense_tensor` returns `True`.
)code",
-      py::return_value_policy::reference_internal);
+      py::return_value_policy::reference_internal)
+    .def("data_ptr",
+        [](TensorList<GPUBackend> &tl) {
+          return py::reinterpret_borrow<py::object>(PyLong_FromVoidPtr(tl.raw_mutable_data()));
+        },
+      R"code(
+      Returns the address of the first element of TensorList.
+      )code");
}

#define GetRegisteredOpsFor(OPTYPE) \
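
Each binding above converts the result of `raw_mutable_data()` to a Python integer via `PyLong_FromVoidPtr`; no ownership or lifetime information travels with it, so the producing Tensor/TensorList must outlive any view built on the address. As a hedged sketch of consuming such an address on the CPU, a hypothetical helper (not part of this commit) using only ctypes and numpy:

    import ctypes
    import numpy as np

    def numpy_view_from_address(address, shape, dtype):
        # Hypothetical helper: wrap a raw CPU address in a zero-copy numpy
        # view. shape and dtype must match the tensor exactly.
        count = int(np.prod(shape))
        nbytes = count * np.dtype(dtype).itemsize
        raw = (ctypes.c_char * nbytes).from_address(address)
        return np.frombuffer(raw, dtype=dtype).reshape(shape)
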
22 changes: 20 additions & 2 deletions dali/test/python/test_backend_impl.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +13,13 @@
# limitations under the License.

from nvidia.dali.backend_impl import *
+from nvidia.dali.pipeline import Pipeline
+import nvidia.dali.ops as ops
import numpy as np
from numpy.testing import assert_array_equal, assert_allclose
from nose.tools import assert_raises
+from test_utils import py_buffer_from_address


def test_create_tensor():
    arr = np.random.rand(3, 5, 6)
Expand All @@ -41,7 +45,7 @@ def test_empty_tensor_tensorlist():
    assert(np.array(tensor).shape == (0,))
    assert(tensorlist.as_array().shape == (0,))

-def test_tensorlist_getitem():
+def test_tensorlist_getitem_cpu():
    arr = np.random.rand(3, 5, 6)
    tensorlist = TensorListCPU(arr, "NHWC")
    list_of_tensors = [x for x in tensorlist]
Expand All @@ -56,6 +60,20 @@ def test_tensorlist_getitem():
    with assert_raises(IndexError):
        tensorlist[-len(tensorlist) - 1]

+def test_data_ptr_tensor_cpu():
+    arr = np.random.rand(3, 5, 6)
+    tensor = TensorCPU(arr, "NHWC")
+    from_tensor = py_buffer_from_address(tensor.data_ptr(), tensor.shape(), tensor.dtype())
+    assert(np.array_equal(arr, from_tensor))
+
+
+def test_data_ptr_tensor_list_cpu():
+    arr = np.random.rand(3, 5, 6)
+    tensorlist = TensorListCPU(arr, "NHWC")
+    tensor = tensorlist.as_tensor()
+    from_tensor_list = py_buffer_from_address(tensorlist.data_ptr(), tensor.shape(), tensor.dtype())
+    assert(np.array_equal(arr, from_tensor_list))

#if 0 // TODO(spanev): figure out which return_value_policy to choose
#def test_tensorlist_getitem_slice():
# arr = np.random.rand(3, 5, 6)
69 changes: 69 additions & 0 deletions dali/test/python/test_backend_impl_gpu.py
@@ -0,0 +1,69 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvidia.dali.backend_impl import *
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import numpy as np
from nose.tools import assert_raises
import cupy as cp
from test_utils import py_buffer_from_address

class ExternalSourcePipe(Pipeline):
    def __init__(self, batch_size, data):
        super(ExternalSourcePipe, self).__init__(batch_size, 1, 0)
        self.output = ops.ExternalSource(device="gpu")
        self.data = data

    def define_graph(self):
        self.out = self.output()
        return self.out

    def iter_setup(self):
        self.feed_input(self.out, self.data)

def test_tensorlist_getitem_gpu():
    arr = np.random.rand(3, 5, 6)
    pipe = ExternalSourcePipe(arr.shape[0], arr)
    pipe.build()
    tensorlist = pipe.run()[0]
    list_of_tensors = [x for x in tensorlist]

    assert(type(tensorlist[0]) != cp.ndarray)
    assert(type(tensorlist[0]) == TensorGPU)
    assert(type(tensorlist[-3]) == TensorGPU)
    assert(len(list_of_tensors) == len(tensorlist))
    with assert_raises(IndexError):
        tensorlist[len(tensorlist)]
    with assert_raises(IndexError):
        tensorlist[-len(tensorlist) - 1]

def test_data_ptr_tensor_gpu():
    arr = np.random.rand(3, 5, 6)
    pipe = ExternalSourcePipe(arr.shape[0], arr)
    pipe.build()
    tensor = pipe.run()[0][0]
    from_tensor = py_buffer_from_address(tensor.data_ptr(), tensor.shape(), tensor.dtype(), gpu=True)
    # from_tensor is a cupy array, so convert arr to cupy as well
    assert(cp.allclose(cp.array(arr[0]), from_tensor))

def test_data_ptr_tensor_list_gpu():
    arr = np.random.rand(3, 5, 6)
    pipe = ExternalSourcePipe(arr.shape[0], arr)
    pipe.build()
    tensor_list = pipe.run()[0]
    tensor = tensor_list.as_tensor()
    from_tensor = py_buffer_from_address(tensor_list.data_ptr(), tensor.shape(), tensor.dtype(), gpu=True)
    # from_tensor is a cupy array, so convert arr to cupy as well
    assert(cp.allclose(cp.array(arr), from_tensor))
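
These tests turn the address into a cupy array through the `__cuda_array_interface__` holder in `py_buffer_from_address` (see test_utils.py below). An alternative, sketched here under the assumption that the DALI object is kept alive as `owner`, is CuPy's explicit unowned-memory API; `cupy_view_from_address` is a hypothetical name:

    import cupy as cp

    def cupy_view_from_address(address, shape, dtype, owner=None):
        # Wrap an existing device allocation without copying; `owner` should
        # be the object that keeps the memory alive (e.g. a DALI TensorGPU).
        nbytes = cp.dtype(dtype).itemsize
        for extent in shape:
            nbytes *= int(extent)
        mem = cp.cuda.UnownedMemory(address, nbytes, owner)
        return cp.ndarray(tuple(shape), dtype=dtype,
                          memptr=cp.cuda.MemoryPointer(mem, 0))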
17 changes: 17 additions & 0 deletions dali/test/python/test_utils.py
@@ -38,6 +38,8 @@ def get_dali_extra_path():
np = None
assert_array_equal = None
assert_allclose = None
+cp = None

def import_numpy():
    global np
    global assert_array_equal
@@ -214,3 +216,18 @@ def dali_type(t):
    if t is np.int32:
        return types.INT32
    raise "Unsupported type: " + str(t)

+def py_buffer_from_address(address, shape, dtype, gpu = False):
+    buff = {'data': (address, False), 'shape': tuple(shape), 'typestr': dtype}
+    class py_holder(object):
+        pass
+
+    holder = py_holder()
+    holder.__array_interface__ = buff
+    holder.__cuda_array_interface__ = buff
+    if not gpu:
+        return np.array(holder, copy=False)
+    else:
+        global cp
+        import cupy as cp
+        return cp.asanyarray(holder)
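
`py_buffer_from_address` works by attaching an `__array_interface__` / `__cuda_array_interface__` dict to a dummy holder object; the `False` in the `'data'` tuple marks the buffer as writable. The same protocol can be demonstrated with plain numpy, independent of DALI:

    import numpy as np

    src = np.arange(12, dtype=np.float32).reshape(3, 4)

    class Holder(object):
        pass

    holder = Holder()
    holder.__array_interface__ = {
        'data': (src.ctypes.data, False),  # (address, read-only flag)
        'shape': src.shape,
        'typestr': src.dtype.str,          # e.g. '<f4'
    }

    view = np.array(holder, copy=False)    # zero-copy view over src's memory
    view[0, 0] = 42.0
    assert src[0, 0] == 42.0
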
12 changes: 2 additions & 10 deletions qa/TL0_python-self-test/test.sh
100755 → 100644
@@ -1,11 +1,3 @@
#!/bin/bash -e
-# used pip packages
-pip_packages="nose numpy opencv-python pillow librosa"
-target_dir=./dali/test/python
-
-# test_body definition is in separate file so it can be used without setup
-source test_body.sh
-
-pushd ../..
-source ./qa/test_template.sh
-popd
+./test_nofw.sh
+./test_cupy.sh
12 changes: 11 additions & 1 deletion qa/TL0_python-self-test/test_body.sh
@@ -12,7 +12,17 @@ test_py() {
    python test_coco_tfrecord.py -i 64
}

-test_body() {
+test_no_fw() {
    test_nose
    test_py
}

+test_cupy() {
+    nosetests --verbose test_backend_impl_gpu.py
+}
+
+
+run_all() {
+    test_no_fw
+    test_cupy
+}
