In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp

In [2]:
# Global PRNG seed; change it here to re-draw every random quantity derived
# from `key`.
SEED = 0
key = random.PRNGKey(SEED)



In [3]:
def init_network_params(layer_sizes=(784, 512, 256, 10), key=None, scale=0.1):
    """Initialise the weights and biases of a fully connected network.

    Every weight and bias entry is a standard-normal sample multiplied by
    ``scale``.

    Args:
        layer_sizes: widths of the layers, input first and output last. The
            default reproduces the 784 -> 512 -> 256 -> 10 network.
        key: ``jax.random`` PRNG key. Defaults to ``random.PRNGKey(0)``, the
            same seed the notebook uses for its global ``key``.
        scale: multiplier applied to every sampled value.

    Returns:
        A list with one ``(weights, bias)`` tuple per layer, where ``weights``
        has shape ``(n_in, n_out)`` and ``bias`` has shape ``(n_out,)``.

    Raises:
        ValueError: if fewer than two layer sizes are given.
    """
    sizes = tuple(layer_sizes)
    if len(sizes) < 2:
        raise ValueError("layer_sizes needs at least an input and an output size")
    if key is None:
        key = random.PRNGKey(0)

    # Derive a distinct key for each layer. Splitting the *same* key for every
    # layer would hand each layer identical subkeys, making the layers'
    # samples correlated instead of independent.
    layer_keys = random.split(key, len(sizes) - 1)

    params = []
    for layer_key, n_in, n_out in zip(layer_keys, sizes[:-1], sizes[1:]):
        w_key, b_key = random.split(layer_key)
        weights = scale * random.normal(w_key, (n_in, n_out))
        bias = scale * random.normal(b_key, (n_out,))
        params.append((weights, bias))
    return params


In [4]:
# Build the network's parameters once: a list of (weights, bias) pairs.
# The next cell prints their shapes.
params = init_network_params()

In [5]:
# Print each layer index followed by its weight shape and bias shape.
for i, param in enumerate(params):
    print(i, *(array.shape for array in param))

0 (784, 512) (512,)
1 (512, 256) (256,)
2 (256, 10) (10,)


In [6]:
def relu(x):
    """Apply the rectified linear unit elementwise.

    Negative entries are replaced by 0; all other entries pass through.
    """
    floor = 0
    rectified = jnp.maximum(floor, x)
    return rectified


# Module-level weights read by `forward` when it is called without `params`.
# Unpack them from `params` (built by init_network_params above) instead of
# re-running a copy of the initialisation code. The weights `forward` uses are
# then exactly the ones listed in the shape printout, and any change to the
# initialiser automatically applies here as well.
(w1, b1), (w2, b2), (w3, b3) = params


def forward(image, params=None):
    """Compute class log-probabilities for a flattened image (or a batch).

    Args:
        image: flattened input whose last axis matches the first layer's
            input width (784 for the default network). A leading batch axis
            is also accepted, because every contraction is over the last axis.
        params: optional list of ``(weights, bias)`` pairs laid out like the
            output of ``init_network_params``. Defaults to the module-level
            ``w1, b1, w2, b2, w3, b3``, so existing ``forward(image)`` calls
            behave as before. Passing it explicitly lets ``grad``/``jit``
            treat the parameters as inputs.

    Returns:
        Log-probabilities over the output classes (last axis), normalised so
        that ``exp(result)`` sums to 1 along that axis.

    Raises:
        ValueError: if ``params`` is empty.
    """
    if params is None:
        params = [(w1, b1), (w2, b2), (w3, b3)]
    if len(params) == 0:
        raise ValueError("params must contain at least one (weights, bias) layer")

    *hidden_layers, (w_out, b_out) = params

    activations = image
    for weights, bias in hidden_layers:
        # x @ W equals W.T @ x for a single vector, and also works row-wise
        # when a batch of images is passed.
        activations = relu(jnp.dot(activations, weights) + bias)

    logits = jnp.dot(activations, w_out) + b_out
    # Normalise over the class axis only, so each row of a batch is
    # normalised separately. For a 1-D input this matches a full reduction.
    return logits - logsumexp(logits, axis=-1, keepdims=True)
    

In [7]:
# Synthetic stand-ins: one flattened 28*28 input, and a batch of three such
# inputs. Both are drawn from PRNGKey(1).
random_flattened_image = random.normal(
    key=random.PRNGKey(1), shape=(28 * 28,)
)
random_flattened_images = random.normal(
    key=random.PRNGKey(1), shape=(3, 28 * 28)
)

In [8]:
# Forward pass on the single synthetic image; display the result's shape.
preds = forward(image=random_flattened_image)
jnp.shape(preds)

(10,)

In [9]:
# Vectorise `forward` over the leading axis of its image argument.
forward_batch = vmap(forward, in_axes=0)

In [10]:
# Batched forward pass over the three synthetic images; display the shape.
batched_preds = forward_batch(random_flattened_images)
jnp.shape(batched_preds)

(3, 10)