## Follow the [Guide](https://roberttlange.github.io/posts/2020/03/blog-post-10/)

### vmap demo

In [32]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random

# Seed the PRNG, then derive one independent subkey per random draw.
# JAX keys are explicit and stateless: passing the same key to several
# random.normal calls would draw X, W and b from the same random stream.
key = random.PRNGKey(1)
key, x_key, w_key, b_key = random.split(key, 4)

batch_dim = 32
feature_dim = 100
hidden_dim = 512

# Generate a batch of inputs, one row per sample
X = random.normal(x_key, (batch_dim, feature_dim))

# Generate Gaussian weights and biases, each from its own subkey
params = [random.normal(w_key, (hidden_dim, feature_dim)),
          random.normal(b_key, (hidden_dim,))]

# Unpack for the demos below; print order matches the recorded cell output
W, b = params
print(W.shape)
print(X.shape)
print(b.shape)

def ReLU(x):
    """Element-wise rectifier: keep positive entries, clamp the rest to zero."""
    floor = 0  # lower bound applied to every entry of x
    return jnp.maximum(floor, x)

def ReLU_Layer(W, x, b):
    """Affine map of `x` by `W` and `b`, followed by the ReLU non-linearity."""
    pre_activation = jnp.dot(W, x) + b
    return ReLU(pre_activation)

def vmap_ReLU_Layer(func):
    """Turn a single-sample layer `func(W, x, b)` into a jitted batched one.

    Only the second argument is mapped, over its leading axis; the first and
    third arguments are passed unchanged to every per-sample call. The
    per-sample results are stacked along axis 0 of the output.
    """
    batched = vmap(func, in_axes=(None, 0, None), out_axes=0)
    return jit(batched)
# Shape check on one sample before batching: projecting X[0] through W,
# then adding the bias.
print("dot product shape")
single_projection = jnp.dot(W, X[0])
print(single_projection.shape)
print((single_projection + b).shape)

# Build the batched, jitted layer and apply it to the whole batch at once.
relu = vmap_ReLU_Layer(ReLU_Layer)
result = relu(W, X, b)
print(result.shape)

(512, 100)
(32, 100)
(512,)
dot product shape
(512,)
(512,)
(32, 512)


In [31]:
## Test: three ways to apply the ReLU layer to a batch (Python loop, hand-batched, vmap)
def relu_layer(params, x):
    """Simple ReLU layer for a single sample.

    Args:
        params: Pair `(W, b)` holding the weight matrix and the bias vector.
        x: Feature vector of one sample.

    Returns:
        `ReLU(dot(W, x) + b)`.
    """
    # Use jnp rather than NumPy so the function can be traced by jit/vmap;
    # NumPy's dot cannot consume JAX tracers.
    return ReLU(jnp.dot(params[0], x) + params[1])


def batch_version_relu_layer(params, x):
    """Hand-written batch version of the ReLU layer (the error-prone way).

    The batch axis has to be handled manually here: `x` carries samples along
    axis 0, so the weights are transposed and applied from the right.

    Args:
        params: Pair `(W, b)` holding the weight matrix and the bias vector.
        x: Batch of samples, one per row.

    Returns:
        `ReLU(dot(x, W.T) + b)`, one output row per input row.
    """
    # Operate on the argument `x`, not on the notebook-level global `X`.
    return ReLU(jnp.dot(x, params[0].T) + params[1])


# Build the batched, jitted layer once so repeated calls reuse the same
# compiled function instead of re-wrapping relu_layer every time.
_batched_relu_layer = jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))


def vmap_relu_layer(params, x):
    """vmap version of the ReLU layer, applied to a batch.

    `relu_layer` is mapped over axis 0 of `x` while `params` is shared by
    every sample; the per-sample outputs are stacked along axis 0.

    Args:
        params: Pair `(W, b)` holding the weight matrix and the bias vector.
        x: Batch of samples, one per row.

    Returns:
        The batched layer output (not the transformed function).
    """
    return _batched_relu_layer(params, x)

# Each call below overwrites `out`; only the last assignment survives the cell.
# 1) Per-sample Python loop over the rows of X, stacked back into one array.
out = jnp.stack([relu_layer(params, sample) for sample in X])
# 2) Hand-batched variant.
out = batch_version_relu_layer(params, X)
# 3) vmap-based variant.
out = vmap_relu_layer(params, X)