In [1]:
from jax import random
import jax.numpy as jnp
from jax.nn import swish

In [None]:
# Multiplier handed to init_network_params as `scale` when `params` is built.
PARAM_SCALE = 1e-2

# Layer widths handed to init_network_params, first to last:
# 28 * 28 inputs, then 512, then 10.
LAYER_SIZE = [28 * 28, 512, 10]


def init_network_params(sizes, key=None, scale=1e-2):
    """Create randomly initialised ``(weights, biases)`` pairs for dense layers.

    Args:
        sizes: Sequence of layer widths. Each consecutive pair ``(m, n)``
            produces one layer.
        key: JAX PRNG key used as the source of randomness. When omitted
            (``None``), ``random.PRNGKey(0)`` is used, so repeated calls
            without a key return the same values.
        scale: Multiplier applied to the ``random.normal`` samples.

    Returns:
        A list with one ``(w, b)`` tuple per consecutive pair in ``sizes``,
        where ``w`` has shape ``(n, m)`` and ``b`` has shape ``(n,)``.
        The list is empty when ``sizes`` has fewer than two entries.
    """
    if key is None:
        # Resolve the default at call time: a PRNGKey(...) call in the
        # signature would run as soon as the ``def`` statement executes.
        key = random.PRNGKey(0)

    def random_layer_params(m, n, key, scale):
        # Separate sub-keys so weights and biases are drawn independently.
        w_key, b_key = random.split(key)
        weights = scale * random.normal(w_key, (n, m))
        biases = scale * random.normal(b_key, (n,))
        return weights, biases

    # One key per entry in ``sizes``; ``zip`` below stops at the shorter
    # inputs, so the last key goes unused. The split count is deliberately
    # left unchanged because altering it may change the values generated
    # for a given ``key``.
    keys = random.split(key, len(sizes))
    return [
        random_layer_params(m, n, k, scale)
        for m, n, k in zip(sizes[:-1], sizes[1:], keys)
    ]


# Module-level network parameters: one (weights, biases) pair per consecutive
# pair of widths in LAYER_SIZE, drawn with the function's default PRNG key and
# scaled by PARAM_SCALE.
params = init_network_params(LAYER_SIZE, scale=PARAM_SCALE)


def predict(params, image):
    """Run a forward pass through the layers and return the final logits.

    Args:
        params: Sequence of ``(w, b)`` pairs, one per layer, in order.
        image: Input passed to the first layer as ``jnp.dot(w, image)``.

    Returns:
        ``jnp.dot(w, x) + b`` for the last layer, with no activation applied.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    x = image
    # Every layer except the last is an affine map followed by swish.
    for weights, bias in hidden_layers:
        x = swish(jnp.dot(weights, x) + bias)

    # The last layer stays linear so the caller receives raw logits.
    out_w, out_b = output_layer
    return jnp.dot(out_w, x) + out_b