In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp

import numpy as np
import time

In [3]:
# 初始化一层神经网络 weight 和 bias
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key) 
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# 初始化 MLP 所有层的 weight 和 bias, sizes 是一个 list，包含每一层的神经元个数
def init_network_params(sizes, key, scale=1e-2):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_size = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_size, random.PRNGKey(0))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# vmap auto-batching
def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    activations = image
    # 手动进行 forward propagation
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [5]:
# test predict for single image
random_flattened_image = random.normal(random.PRNGKey(1), (28*28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [6]:
# 使用 vmap 自动 batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28*28))
batched_predict = vmap(predict, in_axes=(None, 0))
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [7]:
# 效用 与 损失函数
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
    """标签对比"""
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
    """计算梯度并更新参数"""
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

In [46]:
data_dir = 'dataset/mnist/'
with open(data_dir + 'train-images-idx3-ubyte', 'rb') as f:
    train_images = np.frombuffer(f.read(), np.uint8, offset=16)
train_images = train_images.reshape(-1, 28*28)
train_images = jnp.reshape(train_images, (train_images.shape[0], -1))

with open(data_dir + 'train-labels-idx1-ubyte', 'rb') as f:
    train_labels = np.frombuffer(f.read(), np.uint8, offset=8)
train_labels = train_labels.reshape(-1, )
train_labels = one_hot(train_labels, n_targets)

with open(data_dir + 't10k-images-idx3-ubyte', 'rb') as f:
    test_images = np.frombuffer(f.read(), np.uint8, offset=16)
test_images = test_images.reshape(-1, 28*28)
test_images = jnp.reshape(test_images, (test_images.shape[0], -1))

with open(data_dir + 't10k-labels-idx1-ubyte', 'rb') as f:
    test_labels = np.frombuffer(f.read(), np.uint8, offset=8)
test_labels = test_labels.reshape(-1, )
test_labels = one_hot(test_labels, n_targets)

print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


In [56]:
def get_train_batch():
    idx = np.random.randint(0, train_images.shape[0], size=batch_size)
    return train_images[idx, :], train_labels[idx, :]
    

for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(train_images.shape[0] // batch_size):
        params = update(params, *get_train_batch())
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.24 sec
Training set accuracy 0.9804999828338623
Test set accuracy 0.9723999500274658
Epoch 1 in 2.09 sec
Training set accuracy 0.9834666848182678
Test set accuracy 0.9727999567985535
Epoch 2 in 2.25 sec
Training set accuracy 0.9843167066574097
Test set accuracy 0.9734999537467957
Epoch 3 in 2.86 sec
Training set accuracy 0.985883355140686
Test set accuracy 0.973800003528595
Epoch 4 in 2.72 sec
Training set accuracy 0.9869666695594788
Test set accuracy 0.9732999801635742
Epoch 5 in 2.29 sec
Training set accuracy 0.985450029373169
Test set accuracy 0.9715999960899353
Epoch 6 in 2.35 sec
Training set accuracy 0.9890000224113464
Test set accuracy 0.9754999876022339
Epoch 7 in 2.55 sec
Training set accuracy 0.9906833171844482
Test set accuracy 0.9770999550819397
Epoch 8 in 2.54 sec
Training set accuracy 0.9909666776657104
Test set accuracy 0.9770999550819397
Epoch 9 in 2.36 sec
Training set accuracy 0.9924833178520203
Test set accuracy 0.9776999950408936
