Skip to content

Commit

Permalink
complete python reader op python side
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy committed Jun 27, 2018
1 parent 19fd071 commit 811eae7
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 4 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ All parameter, weight, gradient are variables in Paddle.
});

py::class_<LoDTensorArray>(m, "LoDTensorArray")
.def("__init__",
[](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); })
.def("__getitem__",
[](LoDTensorArray &self, size_t i) { return &self.at(i); },
py::return_value_policy::reference)
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import transpiler
from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace
from core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
from transpiler import DistributeTranspiler, InferenceTranspiler, \
memory_optimize, release_memory
from concurrency import (Go, make_channel, channel_send, channel_recv,
Expand Down Expand Up @@ -72,6 +72,7 @@
'backward',
'regularizer',
'LoDTensor',
'LoDTensorArray',
'CPUPlace',
'CUDAPlace',
'CUDAPinnedPlace',
Expand Down
89 changes: 86 additions & 3 deletions python/paddle/fluid/layers/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from layer_function_generator import generate_layer_fn, templatedoc

__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'random_data_generator', 'Preprocessor', 'load'
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
'double_buffer', 'random_data_generator', 'py_reader', 'Preprocessor',
'load'
]


Expand Down Expand Up @@ -431,6 +432,88 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
return monkey_patch_reader_methods(main_prog_var)


def py_reader(capacity, shapes, lod_levels, dtypes, for_parallel=True):
"""
Create a reader and blocking queue for data feeding in Python
This layer returns a Reader Variable and a BlockingQueue.
The BlockingQueue provides `push()` method to push a
`LoDTensorArray` object into the queue in Python side. In C++
side, the Reader Variable would invoke `pop()` method of the
queue to retrieve the feeding data. The process of feeding data
in Python side and fetching data in C++ side can run in parallel.
The BlockingQueue should be closed using `close()` method when
unused.
Args:
capacity(int): The maximum capacity of the BlockingQueue.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable from which we can get feeding data.
BlockingQueue: A blocking queue for data feeding.
Examples:
.. code-block:: python
reader, queue = fluid.layers.py_reader(
capacity=10,
shapes=[[-1,3,224,224], [-1,1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
# Via the blocking queue, we can feed data using threads
def feed_data(queue, feed_images, feed_labels):
for feed_image, feed_label in zip(feed_images, feed_labels):
data = core.LoDTensorArray()
data.append(feed_image)
data.append(feed_label)
queue.push(data)
thread = threading.Thread(target=feed_data, args=(queue, feed_images, feed_labels))
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []

for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))

queue_name = unique_name('lod_tensor_blocking_queue')
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)

startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=unique_name('create_py_reader'))
startup_blk.append_op(
type='create_py_reader',
inputs={'blocking_queue': queue_name},
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks
})

startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True

main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)

if for_parallel:
main_prog_var = parallel(reader=main_prog_var)

return monkey_patch_reader_methods(main_prog_var), feed_queue


def open_files(filenames,
shapes,
lod_levels,
Expand Down
99 changes: 99 additions & 0 deletions python/paddle/fluid/tests/unittests/test_py_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle.fluid as fluid
import numpy as np
from threading import Thread


def feed_data(feed_queue, inputs):
for in_data in inputs:
feed_queue.push(in_data)


class TestPyReader(unittest.TestCase):
def setUp(self):
self.capacity = 10
self.batch_size_min = 10
self.batch_size_max = 20
self.shapes = [(-1, 3, 2, 1), (-1, 1)]
self.lod_levels = [0, 0]
self.dtypes = ['float32', 'int64']
self.pass_num = 20

def test_single_thread_main(self):
self.main(use_thread=False)

def test_multiple_thread_main(self):
self.main(use_thread=True)

def main(self, use_thread=False):
with fluid.program_guard(fluid.Program(), fluid.Program()):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
executor = fluid.Executor(place)

data_file, feed_queue = fluid.layers.py_reader(
capacity=self.capacity,
dtypes=self.dtypes,
lod_levels=self.lod_levels,
shapes=self.shapes)

read_out_data = fluid.layers.read_file(data_file)
self.inputs = []

for i in range(self.pass_num):
in_data = fluid.LoDTensorArray()
batch_size = np.random.random_integers(self.batch_size_min,
self.batch_size_max)
for shape, dtype in zip(self.shapes, self.dtypes):
next_data = np.random.uniform(
low=0, high=1000,
size=(batch_size, ) + shape[1:]).astype(dtype)
in_data.append(executor.as_lodtensor(next_data))

self.inputs.append(in_data)

executor.run(fluid.default_startup_program())
self.outputs = []
if use_thread:
thread = Thread(
target=feed_data, args=(feed_queue, self.inputs))
thread.start()
for in_data in self.inputs:
self.outputs.append(
executor.run(fetch_list=list(read_out_data)))
else:
for in_data in self.inputs:
feed_queue.push(in_data)
self.outputs.append(
executor.run(fetch_list=list(read_out_data)))

feed_queue.close()
self.validate()

def validate(self):
self.assertEqual(len(self.inputs), len(self.outputs))
for in_data_list, out_data_list in zip(self.inputs, self.outputs):
self.assertEqual(len(in_data_list), len(out_data_list))
in_data_list_np = [
np.array(in_lod_tensor) for in_lod_tensor in in_data_list
]
for in_data, out_data in zip(in_data_list_np, out_data_list):
self.assertTrue((in_data == out_data).all())


if __name__ == '__main__':
unittest.main()

0 comments on commit 811eae7

Please sign in to comment.