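# First TensorFlow notebook

A minimal end-to-end example: load the MNIST handwritten-digit dataset, build and train a small Keras classifier, and evaluate it on the held-out test set (about 97.6% test accuracy in the run below).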

In [20]:
import tensorflow as tf

## 1. Load the MNIST dataset of handwritten digits

Pixel values are scaled from 0–255 to the range 0–1 before training.

In [21]:
# Load MNIST: 28x28 digit images with integer class labels, already split into train/test.
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values from 0-255 to 0-1. Casting to float32 first keeps the arrays in the
# dtype the model computes in; a plain `/ 255.0` would produce float64 arrays (twice the
# memory) that Keras then casts down to float32 anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Sanity checks: images match the model's (28, 28) input shape and every image has a label.
assert x_train.shape[1:] == x_test.shape[1:] == (28, 28)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)

## 2. Build the Sequential model

In [22]:
# Fix TensorFlow's global seed so weight initialisation and dropout masks are reproducible
# across "Restart & Run All" (some GPU kernels may still be nondeterministic).
SEED = 42
tf.random.set_seed(SEED)

# A small fully connected classifier:
#   - Flatten turns each 28x28 image into a flat feature vector,
#   - one hidden ReLU layer with 128 units,
#   - Dropout(0.2) for regularisation (only active during training),
#   - a final Dense(10) with NO activation, so the model outputs raw logits, one per
#     digit class. The loss must therefore be configured with from_logits=True.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

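## 3. Get predictions from the untrained model

The last layer has no activation, so the model returns raw *logits*; `tf.nn.softmax` converts them into class probabilities.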

In [23]:
# Run the (still untrained) model on a single training example. The result is a
# (1, 10) array of logits; softmax converts it into class probabilities for display.
first_batch = x_train[:1]
predictions = model(first_batch).numpy()
tf.nn.softmax(predictions).numpy()

array([[0.12611556, 0.05724252, 0.12868938, 0.09790985, 0.07785591,
        0.10759037, 0.18333958, 0.10116836, 0.03791643, 0.0821721 ]],
      dtype=float32)

## 4. Define the loss, then train and evaluate the model

In [24]:
# Cross-entropy for integer class labels, computed directly on the model's logits
# (hence from_logits=True, since the final Dense layer has no softmax).
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Sanity check: loss of the untrained model on the single example predicted above.
# With near-uniform class probabilities this should be roughly -ln(1/10) ≈ 2.3.
# Bound to a name and shown as the cell's last expression so it is actually displayed.
initial_loss = loss_fn(y_train[:1], predictions).numpy()

model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

initial_loss

In [25]:
EPOCHS = 5

# Keep the returned History object: its `.history` dict holds per-epoch loss/accuracy
# for later inspection or plotting, and assigning it suppresses the opaque
# `<History at 0x...>` repr that would otherwise be printed.
history = model.fit(x_train, y_train, epochs=EPOCHS)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f669ca5f5f8>

In [26]:
# Score the trained model on the held-out test set (verbose=2 prints one summary line).
# evaluate returns [loss, accuracy] because 'accuracy' was the only compiled metric.
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=2)
[test_loss, test_accuracy]

313/313 - 0s - loss: 0.0776 - accuracy: 0.9762


[0.07755856961011887, 0.9761999845504761]

In [27]:
# Wrap the trained logit-producing model with a Softmax layer so that calling the
# wrapper returns class probabilities instead of logits.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [28]:
# Class probabilities for the first five test images: one row per image, and the
# column with the largest value is the predicted digit.
first_five_test = x_test[:5]
probability_model(first_five_test)

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[3.7031214e-07, 5.2762100e-08, 9.5548294e-06, 1.0840814e-03,
        1.8826060e-10, 7.0718187e-08, 4.7274710e-12, 9.9889272e-01,
        8.5149497e-07, 1.2417704e-05],
       [6.8114875e-11, 1.7057190e-04, 9.9982721e-01, 9.6886322e-07,
        1.8944097e-15, 1.2860875e-06, 1.1938640e-08, 4.2351871e-14,
        1.6392298e-09, 1.1698577e-15],
       [6.7730880e-08, 9.9896204e-01, 1.9389110e-04, 5.6754325e-06,
        2.9350609e-05, 1.5778644e-05, 1.0009665e-04, 6.7304203e-04,
        1.8923180e-05, 1.2099262e-06],
       [9.9982327e-01, 4.2692450e-09, 9.7346063e-05, 1.5308786e-06,
        4.5259114e-07, 1.7424724e-06, 5.9643098e-05, 1.5698120e-06,
        6.4178010e-09, 1.4417438e-05],
       [2.4470530e-06, 2.7956071e-09, 2.1016313e-05, 1.4615792e-07,
        9.9900156e-01, 5.4628947e-07, 5.4203538e-06, 1.5730459e-04,
        1.1754477e-06, 8.1035960e-04]], dtype=float32)>