diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 07142056f872d..9f888b16d6b2f 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -16,7 +16,7 @@ program. """ -__all__ = ['np_array', 'text_file'] +__all__ = ['np_array', 'text_file', "recordio"] def np_array(x): @@ -55,3 +55,24 @@ def reader(): f.close() return reader + + +def recordio(path): + """ + Creates a data reader that outputs record one one by one from given recordio file + :path: path of recordio file + :returns: data reader of recordio file + """ + + import recordio as rec + + def reader(): + f = rec.reader(path) + while True: + r = f.read() + if r is None: + break + yield r + f.close() + + return reader diff --git a/python/paddle/v2/reader/tests/creator_test.py b/python/paddle/v2/reader/tests/creator_test.py index 359f3eeefbe8e..ba4f558874a01 100644 --- a/python/paddle/v2/reader/tests/creator_test.py +++ b/python/paddle/v2/reader/tests/creator_test.py @@ -34,5 +34,14 @@ def test_text_file(self): self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) +class TestRecordIO(unittest.TestCase): + def test_recordio(self): + path = os.path.join( + os.path.dirname(__file__), "test_recordio_creator.dat") + reader = paddle.v2.reader.creator.recordio(path) + for idx, r in enumerate(reader()): + self.assertSequenceEqual(r, str(idx)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/reader/tests/test_recordio_creator.dat b/python/paddle/v2/reader/tests/test_recordio_creator.dat new file mode 100644 index 0000000000000..17aa89b679618 Binary files /dev/null and b/python/paddle/v2/reader/tests/test_recordio_creator.dat differ