In [39]:
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

y_train = tf.cast(y_train, tf.int32)
y_test = tf.cast(y_test, tf.int32)

In [40]:
learning_rate = 0.001
epochs = 10
batch_size = 200

In [41]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = (
    train_dataset.shuffle(buffer_size=1000)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size)

In [42]:
n_input = 784
n_layer_1 = 30
n_layer_2 = 30
n_output = 10

In [43]:
w1 = tf.Variable(tf.random.normal([n_input, n_layer_1], stddev=0.05))
b1 = tf.Variable(tf.zeros([n_layer_1]))

w2 = tf.Variable(tf.random.normal([n_layer_1, n_layer_2], stddev=0.05))
b2 = tf.Variable(tf.zeros([n_layer_2]))

w_out = tf.Variable(tf.random.normal([n_layer_2, n_output], stddev=0.05))
b_out = tf.Variable(tf.zeros([n_output]))

In [44]:
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)

In [45]:
for epoch in range(epochs):
    for batch_x, batch_y in train_dataset:
        with tf.GradientTape() as tape:
            layer_1 = tf.nn.relu(tf.add(tf.matmul(batch_x, w1), b1))
            layer_2 = tf.nn.relu(tf.add(tf.matmul(layer_1, w2), b2))
            output = tf.add(tf.matmul(layer_2, w_out), b_out)

            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=batch_y, logits=output))

        gradients = tape.gradient(loss, [w1, b1, w2, b2, w_out, b_out])
        optimizer.apply_gradients(zip(gradients, [w1, b1, w2, b2, w_out, b_out]))

    print(f'Epoch {epoch+1}, Loss: {loss.numpy()}')

Epoch 1, Loss: 0.20276160538196564
Epoch 2, Loss: 0.2863790988922119
Epoch 3, Loss: 0.21349841356277466
Epoch 4, Loss: 0.12812009453773499
Epoch 5, Loss: 0.13810642063617706
Epoch 6, Loss: 0.09045713394880295
Epoch 7, Loss: 0.09209824353456497
Epoch 8, Loss: 0.10880052298307419
Epoch 9, Loss: 0.0715223029255867
Epoch 10, Loss: 0.050713252276182175
