In [1]:
import tensorflow as tf
import math
import numpy as np
from tensorflow.keras.utils import to_categorical

In [2]:
# Layer sizes: flattened 28x28 MNIST image (i), hidden units (j), digit classes (k).
input_num, hidden_num, output_num = 784, 200, 10

In [3]:
# Glorot/Xavier-uniform weights and zero biases.
# The local, seeded RNG makes the initial weights identical across kernel
# restarts (and re-runs of this cell) without touching NumPy's global RNG.
INIT_SEED = 0
_init_rng = np.random.RandomState(INIT_SEED)


def _glorot_uniform(fan_in, fan_out):
    """Draw a (fan_in, fan_out) float32 matrix from U(-limit, limit).

    limit = sqrt(6 / (fan_in + fan_out)), i.e. Glorot/Xavier-uniform scaling.
    """
    limit = math.sqrt(6 / (fan_in + fan_out))
    return _init_rng.uniform(-limit, limit, (fan_in, fan_out)).astype("float32")


w1 = tf.Variable(_glorot_uniform(input_num, hidden_num), name="w1")  # (i, j): input -> hidden
w2 = tf.Variable(_glorot_uniform(hidden_num, output_num), name="w2")  # (j, k): hidden <-> class
b1 = tf.Variable(np.zeros(hidden_num, dtype="float32"), name="b1")    # (j,): hidden bias
b2 = tf.Variable(np.zeros(output_num, dtype="float32"), name="b2")    # (k,): class bias
params = [b1, b2, w1, w2]

In [4]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixels to [0, 1] and flatten each image into an `input_num`-vector.
# Using input_num (not a literal 784) makes reshape fail loudly if the two disagree.
x_train = (x_train / 255.0).reshape(-1, input_num).astype("float32")
x_test = (x_test / 255.0).reshape(-1, input_num).astype("float32")

# One-hot labels. Passing num_classes pins the width to output_num instead of
# inferring it from the largest label present in each split.
y_train = to_categorical(y_train, num_classes=output_num)
y_test = to_categorical(y_test, num_classes=output_num)

BATCH_SIZE = 32
SHUFFLE_BUFFER = 10000
SHUFFLE_SEED = 0  # reproducible (but still per-epoch reshuffled) training order

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED)
    .batch(BATCH_SIZE)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

In [5]:
# input: (N, i)
# return: (N, j, k)
@tf.function
def signal_all(input):
    """Pre-activation of every hidden unit j under every class hypothesis k.

    Element [n, j, k] = b1[j] + w2[j, k] + sum_i input[n, i] * w1[i, j].
    """
    bias_per_class = b1[:, tf.newaxis] + w2             # (j, 1) + (j, k) -> (j, k)
    data_term = tf.matmul(input, w1)[:, :, tf.newaxis]  # (N, j) -> (N, j, 1)
    return bias_per_class + data_term                   # broadcasts to (N, j, k)

In [6]:
# input: any shape (probability() passes the (N, j, k) output of signal_all)
# return: same shape as input
@tf.function
def activation(input):
    """Elementwise softplus, log(1 + exp(x)), of the hidden pre-activations.

    Alternative (not used): log(2 * sinh(x) / x). It needs a series expansion
    near x = 0 to avoid 0/0. If revived, note that the series is
    ln 2 + x**2/6 - x**4/180 (positive x**2 term, negative x**4 term).
    """
    return tf.nn.softplus(input)

In [7]:
# input: (N, i)
# return: (N, k)
@tf.function
def probability(input):
    """Class probabilities for a batch of flattened images.

    score[n, k] = b2[k] + sum_j activation(signal_all(input))[n, j, k], and the
    result is softmax(score) over k. The per-row maximum is subtracted first;
    softmax is shift-invariant, so this only keeps exp() in a safe range.
    """
    hidden_terms = activation(signal_all(input))        # (N, j, k)
    scores = tf.reduce_sum(hidden_terms, axis=1) + b2   # (N, k)
    row_max = tf.reduce_max(scores, axis=1, keepdims=True)
    return tf.nn.softmax(scores - row_max)

In [8]:
# probs: (N, k), labels: (N, k) one-hot
# return: scalar
@tf.function
def negative_log_likelihood(probs, labels, epsilon=1e-7):
    """Mean negative log-likelihood of the true class over a batch.

    Args:
        probs: class probabilities, (N, k), e.g. the output of probability().
        labels: one-hot targets, (N, k).
        epsilon: lower bound applied to the true-class probability before the
            log. A float32 softmax can underflow to exactly 0 for a confidently
            wrong prediction. log(0) = -inf would then make the loss, and every
            gradient, inf/NaN. 1e-7 matches the clipping epsilon used by Keras'
            categorical cross-entropy. Clamped examples contribute a finite loss
            and zero gradient, the same trade-off Keras makes.

    Returns:
        Scalar tensor: -mean_n log(max(sum_k probs[n, k] * labels[n, k], epsilon)).
    """
    true_class_prob = tf.reduce_sum(probs * labels, 1)
    true_class_prob = tf.maximum(true_class_prob, epsilon)
    return -tf.reduce_mean(tf.math.log(true_class_prob))

In [9]:
# Running averages, accumulated by train()/test() and printed once per epoch.
# Loss trackers:
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
# Accuracy trackers (compare one-hot labels with predicted probabilities):
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

In [10]:
@tf.function
def train(input, labels, opt):
    """Run one optimisation step on a batch and update the training metrics.

    Args:
        input: batch of flattened images, (N, i).
        labels: one-hot targets, (N, k).
        opt: optimizer whose apply_gradients() updates `params` in place.
    """
    with tf.GradientTape() as tape:
        tape.watch(params)
        predict_probs = probability(input)
        loss = negative_log_likelihood(predict_probs, labels)
    grads = tape.gradient(loss, params)
    opt.apply_gradients(zip(grads, params))

    # `loss` is a per-batch mean. Weighting it by the batch size makes the epoch
    # figure the true per-example mean, even when the final batch is shorter
    # (dataset size not a multiple of the batch size).
    batch_size = tf.cast(tf.shape(input)[0], loss.dtype)
    train_loss(loss, sample_weight=batch_size)
    train_accuracy(labels, predict_probs)

In [11]:
@tf.function
def test(input, labels):
    """Evaluate one batch and update the test metrics (no parameter updates).

    Args:
        input: batch of flattened images, (N, i).
        labels: one-hot targets, (N, k).
    """
    predict_probs = probability(input)
    loss = negative_log_likelihood(predict_probs, labels)
    # `loss` is a per-batch mean. Without weighting, a short final batch (e.g. 16
    # of 10,000 test images in batches of 32) would count as much as a full one.
    # Weighting by batch size makes the reported value the per-example mean.
    batch_size = tf.cast(tf.shape(input)[0], loss.dtype)
    test_loss(loss, sample_weight=batch_size)
    test_accuracy(labels, predict_probs)

In [12]:
EPOCHS = 20
optimizer = tf.keras.optimizers.Adam()


def _reset_metric(metric):
    """Clear a Keras metric's accumulated state.

    Newer Keras releases call this reset_state(). Older TF 2.x releases only
    have reset_states(), and Keras 3 no longer provides the plural form.
    """
    reset = getattr(metric, "reset_state", None) or metric.reset_states
    reset()


epoch_metrics = (train_loss, train_accuracy, test_loss, test_accuracy)
for epoch in range(EPOCHS):
    # Reset at the START of each epoch, so state left over from an interrupted
    # earlier run of this cell cannot leak into the first epoch's report.
    for metric in epoch_metrics:
        _reset_metric(metric)

    for images, labels in train_ds:
        train(images, labels, optimizer)

    for test_images, test_labels in test_ds:
        test(test_images, test_labels)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100))

Epoch 1, Loss: 0.3761788308620453, Accuracy: 89.84833526611328, Test Loss: 0.22725063562393188, Test Accuracy: 93.4000015258789
Epoch 2, Loss: 0.18691329658031464, Accuracy: 94.62000274658203, Test Loss: 0.1530020833015442, Test Accuracy: 95.42000579833984
Epoch 3, Loss: 0.13151738047599792, Accuracy: 96.13500213623047, Test Loss: 0.11788664758205414, Test Accuracy: 96.51000213623047
Epoch 4, Loss: 0.09818290174007416, Accuracy: 97.16500091552734, Test Loss: 0.10213764756917953, Test Accuracy: 96.93000030517578
Epoch 5, Loss: 0.07615720480680466, Accuracy: 97.7933349609375, Test Loss: 0.08914702385663986, Test Accuracy: 97.1500015258789
Epoch 6, Loss: 0.06026115268468857, Accuracy: 98.25, Test Loss: 0.078694187104702, Test Accuracy: 97.44999694824219
Epoch 7, Loss: 0.04831394553184509, Accuracy: 98.64166259765625, Test Loss: 0.07368049025535583, Test Accuracy: 97.69999694824219
Epoch 8, Loss: 0.03859495744109154, Accuracy: 98.94667053222656, Test Loss: 0.07162512838840485, Test Accurac