In [1]:
import tensorflow as tf
import math
import numpy as np
from tensorflow.keras.utils import to_categorical

In [2]:
# Layer sizes: flattened 28x28 MNIST pixels -> hidden units -> 10 digit classes.
input_num, hidden_num, output_num = 784, 200, 10

In [3]:
def glorot_uniform(fan_in, fan_out, rng):
    """Sample a (fan_in, fan_out) float32 matrix from the Glorot/Xavier uniform
    distribution U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))."""
    limit = math.sqrt(6 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, (fan_in, fan_out)).astype("float32")


# Seeded generator so a Restart & Run All reproduces the same initial weights.
# NOTE: this seeds weight initialisation only; the tf.data shuffle is still unseeded.
init_rng = np.random.default_rng(0)

w1 = tf.Variable(glorot_uniform(input_num, hidden_num, init_rng), name="w1")   # (i, j)
w2 = tf.Variable(glorot_uniform(hidden_num, output_num, init_rng), name="w2")  # (j, k)
b1 = tf.Variable(np.zeros(hidden_num, dtype="float32"), name="b1")             # (j,)
b2 = tf.Variable(np.zeros(output_num, dtype="float32"), name="b2")             # (k,)
params = [b1, b2, w1, w2]

In [4]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixels by 1/255, flatten each image to a 784-vector, and store as float32.
x_train = (x_train / 255.0).reshape(-1, 784).astype("float32")
x_test = (x_test / 255.0).reshape(-1, 784).astype("float32")
# Integer class ids -> one-hot rows.
y_train, y_test = to_categorical(y_train), to_categorical(y_test)

# Training: shuffled mini-batches of 100. Evaluation: fixed-order mini-batches of 100.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(100)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(100)

In [5]:
# input: (N, i)
# return: (N, j, k)
@tf.function
def signal_all(input):
    """Pre-activation of hidden unit j for example n under class k.

    alpha[n, j, k] = b1[j] + w2[j, k] + sum_i input[n, i] * w1[i, j]
    """
    label_term = b1[:, tf.newaxis] + w2                   # (j, k)
    data_term = tf.matmul(input, w1)[:, :, tf.newaxis]    # (N, j, 1)
    return label_term + data_term                         # broadcasts to (N, j, k)

$\alpha_j = c_j + U_{j,y} + \sum_{i} W_{j,i} x_i$ とすると, 中間素子 $h_j$ について周辺化した項 $\ln \sum_{h_j} \exp(h_j \alpha_j)$ は

- 中間素子 $h_j$ が $h_j \in \left\{ 0, 1 \right\}$ のとき

$\ln \sum_{h_j \in \{0, 1\}} \exp(h_j \alpha_j) = \ln (1+\exp(\alpha_j)) = \mathrm{softplus}(\alpha_j)$

- 中間素子 $h_j$ が $h_j \in \left[ -1, +1 \right]$ のとき（連続値なので和は積分になる）

$\ln \int_{-1}^{+1} \exp(h_j \alpha_j) \, dh_j = \ln \frac{2 \sinh(\alpha_j)}{\alpha_j}$

In [6]:
# input: any shape (applied elementwise); probability() calls it on (N, j, k)
# return: same shape as input
@tf.function
@tf.custom_gradient
def activation(input):
    """Log-partition of one hidden unit h that is uniform on [-1, +1].

        f(a)  = ln ∫_{-1}^{+1} exp(h a) dh = ln(2 sinh(a) / a)
        f'(a) = coth(a) - 1/a

    Both are evaluated in overflow- and cancellation-free forms. Near a = 0
    the closed forms are 0/0, so a Taylor series is used there instead.
    """
    # Below this |a| the two-term series is accurate to float32 precision
    # (next terms: a**6/2835 for f and 2*a**5/945 for f').
    series_cutoff = 0.1
    near_zero = tf.math.abs(input) < series_cutoff
    # Substitute 1 for near-zero entries before evaluating the closed forms so the
    # branch that tf.where discards never contains inf/NaN.
    safe = tf.where(near_zero, tf.ones_like(input), input)
    abs_safe = tf.math.abs(safe)
    sq = input * input
    ln2 = tf.math.log(tf.constant(2., dtype=input.dtype))

    # ln(2 sinh|a| / |a|) = |a| + ln(1 - exp(-2|a|)) - ln|a|.
    # f is even, so |a| is enough. Unlike sinh(a), this does not overflow for
    # |a| > ~89 in float32.
    closed_form = abs_safe + tf.math.log(-tf.math.expm1(-2. * abs_safe)) - tf.math.log(abs_safe)
    series = ln2 + sq / 6. - sq * sq / 180.
    ret = tf.where(near_zero, series, closed_form)

    def activation_grad(dy):
        # coth(a) - 1/a loses all precision to cancellation as a -> 0, so use
        # its series a/3 - a**3/45 in the same region as the forward pass.
        grad_closed = 1. / tf.math.tanh(safe) - 1. / safe
        grad_series = input / 3. - input * sq / 45.
        return dy * tf.where(near_zero, grad_series, grad_closed)

    return ret, activation_grad

#@tf.function
#def activation(input):
#    return tf.math.softplus(input)

In [7]:
# input: (N, i)
# return: (N, k)
@tf.function
def probability(input):
    """Class probabilities for each example, shape (N, k).

    The score of class k is b2[k] plus the sum over hidden units j of
    activation(alpha[n, j, k]). A softmax over k turns scores into
    probabilities. Both intermediate tensors are checked for inf/NaN.
    """
    signals = signal_all(input)
    signals = tf.debugging.assert_all_finite(signals, "signal_all")    # (N, j, k)
    hidden = activation(signals)
    hidden = tf.debugging.assert_all_finite(hidden, "activation")      # (N, j, k)
    scores = b2 + tf.reduce_sum(hidden, axis=1)                         # (N, k)
    # Shift each row by its maximum before the softmax (softmax is shift-invariant).
    scores -= tf.reduce_max(scores, axis=1, keepdims=True)
    return tf.nn.softmax(scores)

In [8]:
# probs:  (N, k) class probabilities (e.g. from probability())
# labels: (N, k) one-hot targets
# return: scalar mean negative log-likelihood over the batch
@tf.function
def negative_log_likelihood(probs, labels):
    """Mean of -ln p(true class) over the batch.

    The softmax probability of the true class can underflow to exactly 0 for a
    confidently wrong prediction. It is therefore floored at the dtype's
    smallest positive normal value, so the loss stays finite instead of
    becoming inf with NaN gradients. Values above the floor are unaffected.
    """
    # Probability assigned to each example's true class: (N,)
    single_prob = tf.reduce_sum(probs * labels, 1)
    floor = np.finfo(single_prob.dtype.as_numpy_dtype).tiny
    return -tf.reduce_mean(tf.math.log(tf.maximum(single_prob, floor)))

In [9]:
# Running per-epoch averages; the training loop resets them after each epoch.
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name="train_accuracy")

test_loss = tf.keras.metrics.Mean(name="test_loss")
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name="test_accuracy")

In [10]:
@tf.function
def train(input, labels, opt):
    """Run one optimisation step on a mini-batch and update the training metrics.

    Args:
        input: (N, i) batch of flattened images.
        labels: (N, k) one-hot labels.
        opt: a tf.keras optimizer used to update the global ``params`` list.
    """
    # params are tf.Variables created with the default trainable=True, so the
    # tape watches them automatically; no explicit tape.watch is needed.
    with tf.GradientTape() as tape:
        predict_probs = probability(input)
        loss = negative_log_likelihood(predict_probs, labels)
    grads = tape.gradient(loss, params)
    # Use the checked tensors, so the finiteness check is an explicit data
    # dependency of apply_gradients rather than relying on automatic control deps.
    grads = [tf.debugging.assert_all_finite(g, "gradient") for g in grads]
    opt.apply_gradients(zip(grads, params))

    train_loss(loss)
    train_accuracy(labels, predict_probs)

In [11]:
@tf.function
def test(input, labels):
    """Evaluate one mini-batch and accumulate the test metrics (no parameter update)."""
    probs = probability(input)
    test_loss(negative_log_likelihood(probs, labels))
    test_accuracy(labels, probs)

In [12]:
EPOCHS = 20
optimizer = tf.keras.optimizers.Adam()
template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'

for epoch in range(EPOCHS):
    # One pass over the shuffled training set, updating parameters per batch.
    for images, labels in train_ds:
        train(images, labels, optimizer)

    # Evaluate on the whole test set with the current parameters.
    for test_images, test_labels in test_ds:
        test(test_images, test_labels)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        test_loss.result(),
        test_accuracy.result() * 100,
    ))

    # Reset the metrics so the next epoch reports its own averages.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

Epoch 1, Loss: 0.42067691683769226, Accuracy: 88.81999969482422, Test Loss: 0.26803818345069885, Test Accuracy: 92.1199951171875
Epoch 2, Loss: 0.24803711473941803, Accuracy: 92.87833404541016, Test Loss: 0.2197519838809967, Test Accuracy: 93.63999938964844
Epoch 3, Loss: 0.19857341051101685, Accuracy: 94.31499481201172, Test Loss: 0.17811623215675354, Test Accuracy: 94.7300033569336
Epoch 4, Loss: 0.16230112314224243, Accuracy: 95.3800048828125, Test Loss: 0.15066532790660858, Test Accuracy: 95.47000122070312
Epoch 5, Loss: 0.13512328267097473, Accuracy: 96.15166473388672, Test Loss: 0.13194890320301056, Test Accuracy: 96.17000579833984
Epoch 6, Loss: 0.11457814276218414, Accuracy: 96.71666717529297, Test Loss: 0.11781653761863708, Test Accuracy: 96.6300048828125
Epoch 7, Loss: 0.09817301481962204, Accuracy: 97.15166473388672, Test Loss: 0.11145833134651184, Test Accuracy: 96.62000274658203
Epoch 8, Loss: 0.08391623198986053, Accuracy: 97.61166381835938, Test Loss: 0.09947498142719269