diff --git a/paddle/parameter/tests/test_argument.cpp b/paddle/parameter/tests/test_argument.cpp index 81fe4ee397351..98ab013548734 100644 --- a/paddle/parameter/tests/test_argument.cpp +++ b/paddle/parameter/tests/test_argument.cpp @@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) { CHECK_EQ(outStart[3], 4); CHECK_EQ(outStart[4], 7); - CHECK_EQ(stridePositions->getSize(), 8); + CHECK_EQ(stridePositions->getSize(), 8UL); auto result = reversed ? strideResultReversed : strideResult; for (int i = 0; i < 8; i++) { CHECK_EQ(stridePositions->getData()[i], result[i]); diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 418b592a5ac63..9c614914b5e37 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -149,3 +149,57 @@ def reader(): yield line return reader + + +def convert(output_path, + reader, + num_shards, + name_prefix, + max_lines_to_shuffle=1000): + import recordio + import cPickle as pickle + import random + """ + Convert data from reader to recordio format files. + + :param output_path: directory in which output files will be saved. + :param reader: a data reader, from which the convert program will read data instances. + :param num_shards: the number of shards that the dataset will be partitioned into. + :param name_prefix: the name prefix of generated files. + :param max_lines_to_shuffle: the max lines numbers to shuffle before writing. + """ + + assert num_shards >= 1 + assert max_lines_to_shuffle >= 1 + + def open_writers(): + w = [] + for i in range(0, num_shards): + n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i, + num_shards - 1) + w.append(recordio.writer(n)) + + return w + + def close_writers(w): + for i in range(0, num_shards): + w[i].close() + + def write_data(w, lines): + random.shuffle(lines) + for i, d in enumerate(lines): + d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL) + w[i % num_shards].write(d) + + w = open_writers() + lines = [] + + for i, d in enumerate(reader()): + lines.append(d) + if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle: + write_data(w, lines) + lines = [] + continue + + write_data(w, lines) + close_writers(w) diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py index f9815d4f9e1ee..cfa194eba38ea 100644 --- a/python/paddle/v2/dataset/tests/common_test.py +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -57,6 +57,38 @@ def test_cluster_file_reader(self): for idx, e in enumerate(reader()): self.assertEqual(e, str("0")) + def test_convert(self): + record_num = 10 + num_shards = 4 + + def test_reader(): + def reader(): + for x in xrange(record_num): + yield x + + return reader + + path = tempfile.mkdtemp() + paddle.v2.dataset.common.convert(path, + test_reader(), num_shards, + 'random_images') + + files = glob.glob(path + '/random_images-*') + self.assertEqual(len(files), num_shards) + + recs = [] + for i in range(0, num_shards): + n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1) + r = recordio.reader(n) + while True: + d = r.read() + if d is None: + break + recs.append(d) + + recs.sort() + self.assertEqual(total, record_num) + if __name__ == '__main__': unittest.main()