In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from image_plotter import show_image, show_images, show_image_probs
from bokeh.io import output_notebook
import numpy as np

In [2]:
output_notebook()

In [3]:
auto = tf.data.experimental.AUTOTUNE

# Load Data

In [4]:
fashion_mnist, fashion_mnist_info = tfds.load("fashion_mnist", data_dir="/data", with_info=True)
fashion_mnist_info



tfds.core.DatasetInfo(
    name='fashion_mnist',
    version=1.0.0,
    description='Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.',
    urls=['https://github.com/zalandoresearch/fashion-mnist'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{DBLP:journals/corr/abs-1708-07747,
      author    = {Han Xiao and
                   Kashif Rasul and
                   Roland Vollgraf},
      title     = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
                   Algorithms},
      journal   = {CoRR},
      volume  

In [5]:
classes = fashion_mnist_info.features["label"].names
classes

['T-shirt/top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']

In [6]:
testset = tfds.load("fashion_mnist", data_dir="/data", split=tfds.Split.TEST)

first_80_percent = tfds.Split.TRAIN.subsplit(tfds.percent[:80])
trainset = tfds.load("fashion_mnist", data_dir="/data", split=first_80_percent)

last_20_percent = tfds.Split.TRAIN.subsplit(tfds.percent[-20:])
valset = tfds.load("fashion_mnist", data_dir="/data", split=last_20_percent)

In [7]:
images = []
targets = []
for elem in trainset.take(5):
    images.append(elem["image"].numpy().squeeze())
    targets.append(elem["label"])

In [8]:
images[0].shape

(28, 28)

In [9]:
images[0].dtype

dtype('uint8')

In [10]:
show_images(images=images, label_idxs=targets, classes=classes, height=100, width=100)

# Build the Model

In [11]:
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dense(10, activation="softmax")
])
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


# Build the Data Pipeline 

In [12]:
def to_tpl(elem):
    return elem["image"], elem["label"]

def squeeze(image, label):
    return tf.squeeze(image), label

def normalize(image, label):
    return image/255, label

train_ds = trainset.map(to_tpl, auto)
train_ds = train_ds.map(squeeze, auto)
train_ds = train_ds.map(normalize, auto)
train_ds = train_ds.shuffle(512)
# train_ds = train_ds.repeat()
train_ds = train_ds.batch(32)

In [13]:
one_batch = None
for images, labels in train_ds.take(1):
    one_batch = images
one_batch.shape

TensorShape([32, 28, 28])

In [14]:
logits = model(one_batch)
logits.shape

TensorShape([32, 10])

# Train the Model

In [15]:
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

In [None]:
# Dont need to do this if the dataset is not infinitely repeating.
num_batches = (60000 * 0.8)/32
num_batches
model.fit(train_ds, epochs=3, steps_per_epoch=num_batches)

In [16]:
model.fit(train_ds, epochs=3)

W0816 16:53:06.542076 4531324352 deprecation.py:323] From /Users/avilay/.venvs/ai/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x13c7cd748>

# Evaluate the Model

In [17]:
test_ds = testset.map(to_tpl, auto)
test_ds = test_ds.map(squeeze, auto)
test_ds = test_ds.map(normalize, auto)
test_loss, test_acc = model.evaluate(test_ds.batch(32))

    313/Unknown - 2s 5ms/step - loss: 0.3689 - accuracy: 0.8680

In [18]:
print(f"Test: Loss={test_loss:.3f} Accuracy={test_acc*100:.2f}%")

Test: Loss=0.369 Accuracy=86.80%


In [19]:
wrong_images = []
wrong_label_idxs = []
wrong_preds = []
for elem in testset:
    image = elem["image"].numpy().squeeze()
    label_idx = int(elem["label"].numpy().squeeze())
    prob_dist = model(np.expand_dims(image, axis=0)).numpy().squeeze()
    pred_label_idx = int(np.argmax(prob_dist))
    if pred_label_idx != label_idx:
        wrong_images.append(image)
        wrong_label_idxs.append(label_idx)
        wrong_preds.append(prob_dist)

In [20]:
len(wrong_images)

1540

In [22]:
idxs = np.random.choice(len(wrong_images), 5)
images = [wrong_images[x] for x in idxs]
label_idxs = [wrong_label_idxs[x] for x in idxs]
preds = [wrong_preds[x] for x in idxs]
show_image_probs(images=images, label_idxs=label_idxs, probs=preds, classes=classes, height=150, width=150)

# Predict with the Model

In [23]:
images = []
label_idxs = []
for elem in testset.take(5):
    images.append(elem["image"].numpy().squeeze())
    label_idxs.append(elem["label"])

In [24]:
show_images(images=images, label_idxs=label_idxs, classes=classes, height=100, width=100)

In [25]:
preds = []
for image in images:
    image = np.expand_dims(image, axis=0)  # Create a batch of 1
    pred = model(image)
    preds.append(pred.numpy().squeeze())

In [26]:
show_image_probs(images=images, probs=preds, label_idxs=label_idxs, classes=classes, height=150, width=150)