In [1]:
import tensorflow as tf
from tensorflow import keras

In [2]:
(x_train, y_train), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [3]:
def preprocess(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)
  return x, y

def create_dataset(xs, ys, n_classes=10):
  ys = tf.one_hot(ys, depth=n_classes)
  return tf.data.Dataset.from_tensor_slices((xs, ys)) \
    .map(preprocess) \
    .shuffle(len(ys)) \
    .batch(128)

In [4]:
train_dataset = create_dataset(x_train, y_train)
val_dataset = create_dataset(x_val, y_val)

In [5]:
model = keras.Sequential([
    keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    keras.layers.Dense(units=256, activation='relu'),
    keras.layers.Dense(units=192, activation='relu'),
    keras.layers.Dense(units=128, activation='relu'),
    keras.layers.Dense(units=10, activation='softmax')
])

In [6]:
model.compile(optimizer='adam', 
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
    train_dataset.repeat(), 
    epochs=10, 
    steps_per_epoch=500,
    validation_data=val_dataset.repeat(), 
    validation_steps=2
)

Train for 500 steps, validate for 2 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [12]:
predictions = model.predict(val_dataset)
predictions

array([[1.9278690e-09, 3.4556641e-12, 9.3269065e-21, ..., 1.6218377e-20,
        7.9097800e-15, 5.8199265e-15],
       [7.0153397e-11, 7.9746387e-09, 9.9999988e-01, ..., 1.3156461e-12,
        1.1792189e-07, 8.0918811e-17],
       [9.9958652e-01, 4.0543827e-04, 1.6399834e-08, ..., 7.5970439e-09,
        1.6763241e-08, 1.6014641e-07],
       ...,
       [1.5513769e-10, 1.0470096e-12, 3.8800562e-16, ..., 1.0000000e+00,
        1.1387511e-08, 1.3016335e-12],
       [3.3407903e-30, 0.0000000e+00, 2.6473563e-28, ..., 2.9416539e-30,
        5.0572662e-24, 9.2392079e-25],
       [1.6515025e-24, 1.4187717e-32, 3.5361581e-26, ..., 3.3701475e-26,
        3.7973083e-17, 3.3709147e-21]], dtype=float32)

In [13]:
import numpy as np
np.argmax(predictions[0])

3