diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 2e4beb6882789..90830515c1e8e 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -26,8 +26,9 @@ import wmt14 import mq2007 import flowers +import voc2012 __all__ = [ 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing', 'wmt14', 'mq2007', 'flowers' + 'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012' ] diff --git a/python/paddle/v2/dataset/tests/voc2012_test.py b/python/paddle/v2/dataset/tests/voc2012_test.py new file mode 100644 index 0000000000000..31e72ebf5eac0 --- /dev/null +++ b/python/paddle/v2/dataset/tests/voc2012_test.py @@ -0,0 +1,42 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.v2.dataset.voc2012 +import unittest + + +class TestVOC(unittest.TestCase): + def check_reader(self, reader): + sum = 0 + label = 0 + for l in reader(): + self.assertEqual(l[0].size, 3 * l[1].size) + sum += 1 + return sum + + def test_train(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.train()) + self.assertEqual(count, 2913) + + def test_test(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.test()) + self.assertEqual(count, 1464) + + def test_val(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.val()) + self.assertEqual(count, 1449) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/dataset/voc2012.py b/python/paddle/v2/dataset/voc2012.py new file mode 100644 index 0000000000000..617e212d67fbe --- /dev/null +++ b/python/paddle/v2/dataset/voc2012.py @@ -0,0 +1,85 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image dataset for segmentation. +The 2012 dataset contains images from 2008-2011 for which additional +segmentations have been prepared. As in previous years the assignment +to training/test sets has been maintained. The total number of images +with segmentation has been increased from 7,062 to 9,993. +""" + +import tarfile +import io +import numpy as np +from paddle.v2.dataset.common import download +from paddle.v2.image import * +from PIL import Image + +__all__ = ['train', 'test', 'val'] + +VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\ +VOCtrainval_11-May-2012.tar' + +VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd' +SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt' +DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg' +LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png' + +CACHE_DIR = 'voc2012' + + +def reader_creator(filename, sub_name): + + tarobject = tarfile.open(filename) + name2mem = {} + for ele in tarobject.getmembers(): + name2mem[ele.name] = ele + + def reader(): + set_file = SET_FILE.format(sub_name) + sets = tarobject.extractfile(name2mem[set_file]) + for line in sets: + line = line.strip() + data_file = DATA_FILE.format(line) + label_file = LABEL_FILE.format(line) + data = tarobject.extractfile(name2mem[data_file]).read() + label = tarobject.extractfile(name2mem[label_file]).read() + data = Image.open(io.BytesIO(data)) + label = Image.open(io.BytesIO(label)) + data = np.array(data) + label = np.array(label) + yield data, label + + return reader + + +def train(): + """ + Create a train dataset reader containing 2913 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'trainval') + + +def test(): + """ + Create a test dataset reader containing 1464 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'train') + + +def val(): + """ + Create a val dataset reader containing 1449 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'val') diff --git a/python/setup.py.in b/python/setup.py.in index b1041f6102a56..65a26940d4d70 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -20,6 +20,7 @@ setup_requires=["requests", "matplotlib", "rarfile", "scipy>=0.19.0", + "Pillow", "nltk"] if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: