In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Parameters
learning_rate = 0.001
training_epochs = 6
batch_size = 600

# Import MNIST data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

train_dataset = (
    tf.data.Dataset.from_tensor_slices((tf.reshape(x_train, [-1, 784]), y_train))
    .batch(batch_size)
    .shuffle(1000)
)

train_dataset = (
    train_dataset.map(lambda x, y:
                      (tf.divide(tf.cast(x, tf.float32), 255.0),
                       tf.reshape(tf.one_hot(y, 10), (-1, 10))))
)

In [3]:
# Set model weights
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

In [4]:
# Construct model
model = lambda x: tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
# Minimize error using cross entropy
compute_loss = lambda true, pred: tf.reduce_mean(tf.reduce_sum(tf.losses.binary_crossentropy(true, pred), axis=-1))
# caculate accuracy
compute_accuracy = lambda true, pred: tf.reduce_mean(tf.keras.metrics.categorical_accuracy(true, pred))
# Gradient Descent
optimizer = tf.optimizers.Adam(learning_rate)

for epoch in range(training_epochs):
    for i, (x_, y_) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            pred = model(x_)
            loss = compute_loss(y_, pred)
        acc = compute_accuracy(y_, pred)
        grads = tape.gradient(loss, [W, b])
        optimizer.apply_gradients(zip(grads, [W, b]))
        print("=> loss %.2f acc %.2f" %(loss.numpy(), acc.numpy()))

=> loss 195.05 acc 0.10
=> loss 191.49 acc 0.57
=> loss 188.84 acc 0.53
=> loss 186.42 acc 0.57
=> loss 182.66 acc 0.57
=> loss 178.83 acc 0.67
=> loss 178.17 acc 0.61
=> loss 174.41 acc 0.62
=> loss 171.08 acc 0.67
=> loss 169.72 acc 0.71
=> loss 165.83 acc 0.66
=> loss 163.73 acc 0.71
=> loss 166.10 acc 0.63
=> loss 163.47 acc 0.68
=> loss 155.01 acc 0.72
=> loss 151.60 acc 0.70
=> loss 153.95 acc 0.72
=> loss 149.91 acc 0.69
=> loss 143.68 acc 0.75
=> loss 144.36 acc 0.74
=> loss 146.78 acc 0.70
=> loss 138.88 acc 0.75
=> loss 143.94 acc 0.74
=> loss 134.62 acc 0.76
=> loss 131.31 acc 0.78
=> loss 130.84 acc 0.77
=> loss 128.80 acc 0.80
=> loss 126.33 acc 0.78
=> loss 122.29 acc 0.77
=> loss 126.45 acc 0.79
=> loss 120.56 acc 0.81
=> loss 120.24 acc 0.77
=> loss 123.38 acc 0.74
=> loss 120.65 acc 0.76
=> loss 116.62 acc 0.82
=> loss 111.72 acc 0.81
=> loss 103.49 acc 0.85
=> loss 104.84 acc 0.79
=> loss 109.73 acc 0.77
=> loss 105.46 acc 0.80
=> loss 107.44 acc 0.80
=> loss 109.96 a

=> loss 46.59 acc 0.87
=> loss 37.51 acc 0.90
=> loss 40.45 acc 0.89
=> loss 36.56 acc 0.91
=> loss 44.61 acc 0.89
=> loss 36.38 acc 0.92
=> loss 39.42 acc 0.90
=> loss 31.19 acc 0.92
=> loss 33.75 acc 0.92
=> loss 33.85 acc 0.92
=> loss 38.62 acc 0.91
=> loss 49.52 acc 0.86
=> loss 41.80 acc 0.89
=> loss 42.63 acc 0.88
=> loss 32.95 acc 0.91
=> loss 41.53 acc 0.88
=> loss 36.28 acc 0.91
=> loss 38.24 acc 0.90
=> loss 38.05 acc 0.90
=> loss 30.17 acc 0.92
=> loss 49.28 acc 0.85
=> loss 42.25 acc 0.88
=> loss 37.91 acc 0.88
=> loss 36.97 acc 0.90
=> loss 36.47 acc 0.92
=> loss 40.56 acc 0.89
=> loss 38.35 acc 0.90
=> loss 49.30 acc 0.86
=> loss 44.14 acc 0.88
=> loss 34.42 acc 0.90
=> loss 44.10 acc 0.88
=> loss 47.79 acc 0.87
=> loss 33.56 acc 0.91
=> loss 46.88 acc 0.87
=> loss 36.54 acc 0.90
=> loss 25.69 acc 0.94
=> loss 35.17 acc 0.92
=> loss 38.99 acc 0.90
=> loss 43.83 acc 0.88
=> loss 31.36 acc 0.91
=> loss 49.78 acc 0.86
=> loss 45.50 acc 0.88
=> loss 29.64 acc 0.93
=> loss 39.