In [1]:
import tensorflow_datasets as tfds

## 1. tfds.load() : with_info=True 지정시 ds_info에 들어있는 정보

In [2]:
# Load MNIST; with with_info=True the call is unpacked into the datasets
# and a second object (ds_info) describing the dataset.
dataset, ds_info = tfds.load(
    name='mnist',
    shuffle_files=True,
    with_info=True,
)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [3]:
# List every attribute name reachable on ds_info (dunder and private names included).
dir(ds_info)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_builder',
 '_compute_dynamic_properties',
 '_dataset_info_path',
 '_features',
 '_fully_initialized',
 '_info_proto',
 '_license_path',
 '_metadata',
 '_set_splits',
 '_splits',
 'as_json',
 'as_proto',
 'citation',
 'compute_dynamic_properties',
 'data_dir',
 'dataset_size',
 'description',
 'download_size',
 'features',
 'full_name',
 'homepage',
 'initialize_from_bucket',
 'initialized',
 'metadata',
 'name',
 'read_from_directory',
 'redistribution_info',
 'splits',
 'supervised_keys',
 'update_splits_if_different',
 'version',
 'write_to_directory']

In [4]:
# Display the feature description stored in ds_info (keys with their shape and dtype).
ds_info.features

FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})

In [5]:
# Display the splits recorded in ds_info together with their example counts.
ds_info.splits

{'test': <tfds.core.SplitInfo num_examples=10000>,
 'train': <tfds.core.SplitInfo num_examples=60000>}

In [9]:
# Read the number of examples of the 'train' and 'test' splits from ds_info.
train_split, test_split = ds_info.splits['train'], ds_info.splits['test']
n_train, n_test = train_split.num_examples, test_split.num_examples

print(n_train, n_test)

60000 10000


## 2. tfds.load() : dataset의 형식

In [10]:
# Load MNIST without with_info and check what kind of object comes back.
dataset = tfds.load(
    name='mnist',
    shuffle_files=True,
)
print(type(dataset))

<class 'dict'>


In [11]:
# Show the keys of the returned dict and the objects stored under them,
# one per output line.
print(dataset.keys(), dataset.values(), sep='\n')

dict_keys(['test', 'train'])
dict_values([<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>, <_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>])


In [12]:
# Pick each split out of the dict by name and print the type of both objects.
train_ds, test_ds = dataset['train'], dataset['test']

print(type(train_ds), type(test_ds), sep='\n')

<class 'tensorflow.python.data.ops.dataset_ops._OptionsDataset'>
<class 'tensorflow.python.data.ops.dataset_ops._OptionsDataset'>


In [14]:
# Look at the first element only: print its type and the keys it carries.
for sample in train_ds.take(1):
    print(type(sample))
    print(sample.keys())

<class 'dict'>
dict_keys(['image', 'label'])


In [15]:
# Print the shapes of the 'image' and 'label' entries of the first element.
for sample in train_ds.take(1):
    images, labels = sample['image'], sample['label']

    print(images.shape)
    print(labels.shape)

(28, 28, 1)
()


In [17]:
# When batch() is applied to a dataset, iteration yields data one batch at a time.
train_ds = dataset['train'].batch(32)

for batch in train_ds.take(1):
    images, labels = batch['image'], batch['label']

    print(images.shape)
    print(labels.shape)


(32, 28, 28, 1)
(32,)


## 3. tfds.load() : as_supervised=True 지정

In [22]:
# Reload MNIST with as_supervised=True, then batch the train split (32 per batch)
# and keep the test split as loaded.
dataset = tfds.load(
    name='mnist',
    shuffle_files=True,
    as_supervised=True,
)
train_ds, test_ds = dataset['train'].batch(32), dataset['test']

In [23]:
# Check the type of the first element.
for element in train_ds.take(1):
    print(type(element))   # with as_supervised=True the element is a tuple

<class 'tuple'>


In [24]:
# Access the two members of the first element by position and print their shapes.
for element in train_ds.take(1):
    images, labels = element[0], element[1]

    print(images.shape)
    print(labels.shape)

(32, 28, 28, 1)
(32,)


In [25]:
# Form available when as_supervised=True is set: unpack the tuple
# directly in the loop header.
for images, labels in train_ds.take(1):
    print(images.shape)
    print(labels.shape)

(32, 28, 28, 1)
(32,)


## 4. tfds.load(): split 지정

In [28]:
# Request the 'train' and 'test' splits explicitly; the result is unpacked
# straight into train_ds and test_ds.
train_ds, test_ds = tfds.load(
    name='mnist',
    shuffle_files=True,
    as_supervised=True,
    split=['train', 'test'],
)

# Batch the train split and print the shapes of the first batch.
train_ds = train_ds.batch(32)
for images, labels in train_ds.take(1):
    print(images.shape)
    print(labels.shape)

(32, 28, 28, 1)
(32,)


### with_info=True, split 함께 지정시

In [29]:
# train_ds and test_ds have to be grouped in a tuple on the left-hand side,
# with ds_info unpacked next to that group.
(train_ds, test_ds), ds_info = tfds.load(
    name='mnist',
    shuffle_files=True,
    as_supervised=True,
    split=['train', 'test'],
    with_info=True,
)

# One object per output line: both datasets, then the dataset info.
print(train_ds, test_ds, ds_info, sep='\n')

<_OptionsDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>
tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)



In [30]:
# Load patch_camelyon together with its info object and print that info.
# NOTE: the tfds log recorded for this cell reports a 7.48 GiB download
# when the dataset is prepared for the first time.
dataset, ds_info = tfds.load(
    name='patch_camelyon',
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
print(ds_info)

[1mDownloading and preparing dataset patch_camelyon/2.0.0 (download: 7.48 GiB, generated: Unknown size, total: 7.48 GiB) to /root/tensorflow_datasets/patch_camelyon/2.0.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/patch_camelyon/2.0.0.incompleteUMP7DA/patch_camelyon-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=32768.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/patch_camelyon/2.0.0.incompleteUMP7DA/patch_camelyon-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=262144.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/patch_camelyon/2.0.0.incompleteUMP7DA/patch_camelyon-validation.tfrecord


HBox(children=(FloatProgress(value=0.0, max=32768.0), HTML(value='')))

[1mDataset patch_camelyon downloaded and prepared to /root/tensorflow_datasets/patch_camelyon/2.0.0. Subsequent calls will reuse this data.[0m
tfds.core.DatasetInfo(
    name='patch_camelyon',
    version=2.0.0,
    description='The PatchCamelyon benchmark is a new and challenging image classification
dataset. It consists of 327.680 color images (96 x 96px) extracted from
histopathologic scans of lymph node sections. Each image is annoted with a
binary label indicating presence of metastatic tissue. PCam provides a new
benchmark for machine learning models: bigger than CIFAR10, smaller than
Imagenet, trainable on a single GPU.',
    homepage='https://patchcamelyon.grand-challenge.org/',
    features=FeaturesDict({
        'id': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(96, 96, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
    }),
    total_num_examples=327680,
    splits={
        'test': 32768,
        'train': 262144

In [32]:
# Print the feature description (followed by a blank line) and then the splits.
features, splits = ds_info.features, ds_info.splits
print(features, '\n')
print(splits)

FeaturesDict({
    'id': Text(shape=(), dtype=tf.string),
    'image': Image(shape=(96, 96, 3), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
}) 

{'test': <tfds.core.SplitInfo num_examples=32768>, 'train': <tfds.core.SplitInfo num_examples=262144>, 'validation': <tfds.core.SplitInfo num_examples=32768>}
