diff --git a/datasets/mnist.py b/datasets/mnist.py index e36a68a..121b725 100644 --- a/datasets/mnist.py +++ b/datasets/mnist.py @@ -72,7 +72,7 @@ def export_mnist(src_path, dataset_path): class MNIST(object): - def __init__(self, dataset_path, num_threads=8, batch_size=128, + def __init__(self, dataset_path, num_threads=8, batch_size=100, shuffle=True, normalize=True, augment=True, one_hot=True): """ :param dataset_path: The dataset folder path. @@ -135,8 +135,7 @@ def _measure_mean_and_std(self): dataset = self.train_set.shuffle(buffer_size=num_samples) dataset = dataset.map( self._read_image_func, - num_threads=self.num_threads, - output_buffer_size=2 * self.batch_size) + num_parallel_calls=self.num_threads) dataset = dataset.batch(num_samples) iterator = dataset.make_one_shot_iterator() images, labels = iterator.get_next()