## MNIST in TF using TensorFlow-Nightly

This notebook loads the MNIST database from TensorFlow Datasets (*tfds*), running on *tf-nightly*.

*tf-nightly* is a pip package built from the latest TensorFlow source and released to PyPI every night, so it includes the most recent changes before they reach a stable release.

**Workflow**
1. Download the dataset using *tfds.load* (TensorFlow Datasets)
2. Save the data as *tfrecord* files
3. Load the *tfrecord* files and create a *tf.data.Dataset*

Reference:
https://www.tensorflow.org/datasets/overview
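
Steps 2 and 3 of the workflow can be sketched by hand on synthetic data, without downloading anything. In this notebook TFDS writes and reads the *tfrecord* files internally, so the sketch below is only an illustration: the feature names, the raw-bytes image encoding, the file name, and the helpers `serialize_example` and `parse_example` are assumptions, not the format TFDS uses on disk.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def serialize_example(image, label):
    """Pack one (image, label) pair into a serialized tf.train.Example (hypothetical helper)."""
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


def parse_example(record):
    """Decode one serialized record back into an (image, label) pair (hypothetical helper)."""
    spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(record, spec)
    image = tf.reshape(tf.io.decode_raw(parsed["image"], tf.uint8), (28, 28, 1))
    return image, parsed["label"]


# Synthetic stand-in for MNIST: 8 random uint8 images with the same shape as the real data.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 28, 28, 1), dtype=np.uint8)
labels = rng.integers(0, 10, size=8)

# Step 2: save the data as a tfrecord file.
path = os.path.join(tempfile.mkdtemp(), "synthetic_mnist.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for image, label in zip(images, labels):
        writer.write(serialize_example(image, label))

# Step 3: load the tfrecord file and create a tf.data.Dataset.
restored = tf.data.TFRecordDataset(path).map(parse_example)
pairs = [(img.numpy(), int(lbl.numpy())) for img, lbl in restored]
restored_images = np.stack([img for img, _ in pairs])
restored_labels = [lbl for _, lbl in pairs]
print(restored_images.shape, restored_labels)
```

The round trip returns the same pixels and labels that were written, which is the property the workflow relies on when the dataset is reloaded from disk.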

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

print("TF Version: ", tf.__version__)

TF Version:  2.2.0


In [2]:
# Print the list of available datasets (optional)
# tfds.list_builders()

In [3]:
# Load MNIST from tfds, reading directly from the public GCS bucket (try_gcs=True)
ds = tfds.load('mnist', split='train', shuffle_files=True, try_gcs=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)

<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>


In [4]:
# Create the tfrecord files (no-op if they already exist)
builder = tfds.builder('mnist')
builder.download_and_prepare()

# Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)

<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>


### Iterate over Dataset

#### Method 1 - As dict:
By default, each element of the *tf.data.Dataset* is a *dict* of *tf.Tensor* objects keyed by feature name (here *image* and *label*).

#### Method 2 - As tuple (supervised):
With *as_supervised=True*, each element is a tuple *(features, label)* instead of a dict.

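#### Method 3 - As numpy:
Wrapping the dataset in *tfds.as_numpy* yields *numpy.ndarray* values instead of *tf.Tensor* objects.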

#### Method 4 - As batched tf.Tensor: 
With *batch_size=-1*, *tfds.load* returns the full dataset as a single batch: a *dict* of *tf.Tensor*, or a tuple when *as_supervised=True*. The code below also wraps the result in *tfds.as_numpy*, so it prints *numpy.ndarray* values.

**Warning: Make sure the dataset fits in memory and that all examples have the same shape.**
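
A rough way to check the memory warning before loading a full batch is to multiply out the shapes and dtypes that the dataset reports. The sketch below assumes the element spec printed earlier (a `(28, 28, 1)` `uint8` image and an `int64` label) and the well-known MNIST split sizes of 10,000 test and 60,000 training examples; `full_batch_bytes` is a hypothetical helper, and the result ignores any framework overhead.

```python
# Element spec as printed by the dataset: image (28, 28, 1) uint8, label () int64.
IMAGE_SHAPE = (28, 28, 1)
BYTES_PER_PIXEL = 1  # uint8
BYTES_PER_LABEL = 8  # int64


def full_batch_bytes(num_examples):
    """Estimate the raw size of a single batch holding num_examples examples."""
    pixels = 1
    for dim in IMAGE_SHAPE:
        pixels *= dim
    return num_examples * (pixels * BYTES_PER_PIXEL + BYTES_PER_LABEL)


test_split_bytes = full_batch_bytes(10_000)
train_split_bytes = full_batch_bytes(60_000)
print(f"test:  {test_split_bytes / 1e6:.2f} MB")
print(f"train: {train_split_bytes / 1e6:.2f} MB")
```

MNIST is small enough that either split fits comfortably in memory; the same arithmetic is more useful for larger datasets, where a full batch may not fit.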

Reference:
https://stackoverflow.com/questions/42480111/model-summary-in-pytorch

In [5]:
# Select iteration method:
method = 1

def as_dict():
    ds = tfds.load('mnist', split='train')
    ds = ds.take(1)  # Only take a single example

    for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
        print(list(example.keys()))
        image = example["image"]
        label = example["label"]
        print(image.shape, label)
    print("Iteration method: ", "As Dict")

def as_tuple():
    ds = tfds.load('mnist', split='train', as_supervised=True)
    ds = ds.take(1)

    for image, label in ds:  # example is (image, label)
        print(image.shape, label)
    print("Iteration method: ", "As Tuple")
    
def as_numpy():
    ds = tfds.load('mnist', split='train', as_supervised=True)
    ds = ds.take(1)

    for image, label in tfds.as_numpy(ds):
        print(type(image), type(label), label)
    print("Iteration method: ", "As Numpy")

def as_batched_tfTensor():
    image, label = tfds.as_numpy(
        tfds.load('mnist', split='test', batch_size=-1, as_supervised=True))
    print(type(image), image.shape)
    print("Iteration method: ", "As Batched tf.Tensor")
    
def iteration_method(method):
    # Map each method number to its function. The values are function objects
    # (no parentheses), so only the selected function runs.
    switcher = {
        1: as_dict,
        2: as_tuple,
        3: as_numpy,
        4: as_batched_tfTensor,
    }
    if method not in switcher:
        raise ValueError(f"Unknown iteration method: {method}")
    switcher[method]()

iteration_method(method)

['image', 'label']
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
Iteration method:  As Dict
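
The dict-based switch in the cell above depends on one detail of Python: a dict value written as `name()` is evaluated while the dict is being built, whereas a value written as `name` stores the function for later. A minimal sketch of the dispatch pattern follows, using the hypothetical stand-in functions `first` and `second` rather than the notebook's loaders.

```python
calls = []  # Records which functions actually ran.


def first():
    calls.append("first")
    return "first"


def second():
    calls.append("second")
    return "second"


def run_selected(method):
    # Store the function objects themselves; nothing is called while the dict is built.
    switcher = {1: first, 2: second}
    handler = switcher.get(method)
    if handler is None:
        raise ValueError(f"Unknown method: {method}")
    return handler()  # Only the selected function runs.


result = run_selected(2)
print(result, calls)
```

Writing `{1: first(), 2: second()}` instead would run both functions immediately, regardless of which method was requested.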


### Visualization

*tf.data.Dataset* objects can be converted to a *pandas.DataFrame* with *tfds.as_dataframe* so that they can be visualized in Colab.

1. Pass the *tfds.core.DatasetInfo* as the second argument to render images, audio, text, videos, etc.
2. Use *ds.take(x)* to display only the first *x* examples. A *pandas.DataFrame* loads the full dataset in memory, which can be very expensive.

In [6]:
ds, info = tfds.load('mnist', split='train', with_info=True)

tfds.as_dataframe(ds.take(4), info)

| | image | label |
|---|---|---|
| 0 | (image) | 4 |
| 1 | (image) | 1 |
| 2 | (image) | 0 |
| 3 | (image) | 7 |


In [7]:
import netron

# Question: where can I access the model?
# netron.start expects a path to a saved model file; `ds` is a dataset, not a model.
# netron.start(ds, port=8081)