In [13]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense,Flatten,InputLayer

In [7]:
mnist_dataset, mnist_info =tfds.load(name='mnist',with_info=True,as_supervised=True)

mnist_train,mnist_test = mnist_dataset['train'], mnist_dataset['test']

num_validation_samples = 0.1* mnist_info.splits['train'].num_examples
num_validation_samples = tf.cast(num_validation_samples, tf.int64)

num_test_samples =mnist_info.splits['test'].num_examples
num_test_samples = tf.cast(num_test_samples, tf.int64)

def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.
    return image, label


scaled_train_and_validation_data = mnist_train.map(scale)

test_data = mnist_test.map(scale)

BUFFER_SIZE = 10000

shuffled_train_and_validation_data = scaled_train_and_validation_data.shuffle(BUFFER_SIZE)

validation_data = shuffled_train_and_validation_data.take(num_validation_samples)

train_data = shuffled_train_and_validation_data.skip(num_validation_samples)

BATCH_SIZE =100

train_data =train_data.batch(BATCH_SIZE)

validation_data = validation_data.batch(num_validation_samples)

test_data = test_data.batch(num_test_samples)

validation_input, validation_targets = next(iter(validation_data))


In [15]:
# model

# outline the model

input_size = 784
output_size = 10

model = Sequential()
model.add(Flatten(input_shape=(28,28,1)))
model.add(Dense(50, activation='relu'))
model.add(Dense(50, activation='relu'))
model.add(Dense(output_size, activation='softmax'))


In [16]:
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 50)                39250     
                                                                 
 dense_1 (Dense)             (None, 50)                2550      
                                                                 
 dense_2 (Dense)             (None, 10)                510       
                                                                 
Total params: 42310 (165.27 KB)
Trainable params: 42310 (165.27 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [18]:
# compile model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [19]:
# train the model
model.fit(train_data, epochs=5, validation_data=(validation_input, validation_targets), verbose=3)

Epoch 1/5












Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x29da7fe9cd0>

In [20]:
model.evaluate(test_data)



[0.11678825318813324, 0.9649999737739563]

In [21]:
model.evaluate(train_data)



[0.08339910954236984, 0.9753333330154419]

In [22]:
y_pred = model.predict(test_data)



In [23]:
[np.argmax(i) for i in y_pred]

[2,
 0,
 4,
 8,
 7,
 6,
 0,
 6,
 3,
 1,
 8,
 0,
 7,
 9,
 8,
 4,
 5,
 3,
 4,
 0,
 6,
 6,
 3,
 0,
 2,
 3,
 6,
 6,
 7,
 4,
 0,
 3,
 8,
 2,
 5,
 4,
 2,
 5,
 5,
 8,
 5,
 2,
 9,
 2,
 4,
 2,
 7,
 0,
 5,
 1,
 0,
 7,
 9,
 9,
 9,
 6,
 5,
 8,
 8,
 6,
 9,
 9,
 5,
 4,
 2,
 6,
 8,
 1,
 0,
 6,
 9,
 5,
 5,
 4,
 1,
 6,
 7,
 5,
 2,
 9,
 0,
 6,
 4,
 4,
 2,
 8,
 7,
 8,
 3,
 0,
 9,
 0,
 1,
 1,
 9,
 4,
 5,
 9,
 1,
 6,
 6,
 0,
 7,
 7,
 8,
 4,
 8,
 3,
 1,
 8,
 0,
 2,
 9,
 1,
 0,
 3,
 9,
 7,
 0,
 4,
 9,
 6,
 8,
 9,
 3,
 5,
 4,
 3,
 2,
 2,
 4,
 0,
 3,
 2,
 3,
 5,
 1,
 1,
 2,
 4,
 8,
 2,
 2,
 6,
 8,
 6,
 1,
 2,
 0,
 6,
 1,
 7,
 2,
 3,
 4,
 4,
 6,
 9,
 0,
 8,
 2,
 8,
 6,
 0,
 9,
 7,
 8,
 2,
 9,
 0,
 2,
 3,
 2,
 8,
 7,
 0,
 1,
 7,
 8,
 5,
 1,
 2,
 3,
 1,
 9,
 8,
 7,
 9,
 0,
 8,
 3,
 9,
 3,
 9,
 7,
 4,
 9,
 3,
 1,
 3,
 2,
 6,
 0,
 7,
 6,
 9,
 5,
 9,
 7,
 4,
 2,
 0,
 1,
 0,
 4,
 9,
 9,
 1,
 9,
 2,
 2,
 0,
 5,
 6,
 0,
 0,
 0,
 7,
 7,
 4,
 7,
 5,
 1,
 7,
 7,
 8,
 2,
 5,
 9,
 6,
 6,
 8,
 0,
 1,
 2,
 1,
 7,
 7,
 7,
 3,
