## Follow the [Guide](https://roberttlange.github.io/posts/2020/03/blog-post-10/)

### vmap demo

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
import jax

# Root PRNG key. JAX random functions are pure (the same key always yields the
# same numbers), so each independent draw below gets its own subkey instead of
# reusing `key` for inputs, weights and biases.
key = random.PRNGKey(1)
x_key, w_key, b_key = random.split(key, 3)

batch_dim = 32
feature_dim = 100
hidden_dim = 512

# Generate a batch of inputs: shape (batch_dim, feature_dim)
X = random.normal(x_key, (batch_dim, feature_dim))

# Generate Gaussian weights (hidden_dim, feature_dim) and biases (hidden_dim,)
params_old = [random.normal(w_key, (hidden_dim, feature_dim)),
              random.normal(b_key, (hidden_dim,))]

W = params_old[0]
print(W.shape)
print(X.shape)
b = params_old[1]
print(b.shape)

def ReLU(x):
    """Return the Rectified Linear Unit max(0, x), applied elementwise."""
    activated = jnp.maximum(0, x)
    return activated

def ReLU_Layer(W, x, b):
    """Dense layer followed by ReLU for one example: ReLU(W @ x + b)."""
    pre_activation = jnp.dot(W, x) + b
    return ReLU(pre_activation)

def vmap_ReLU_Layer(func):
    """Batch a single-example layer function ``func(W, x, b)`` over x.

    W and b are shared across the batch (in_axes None); x is mapped over its
    leading axis and the outputs are stacked along axis 0. The batched
    function is wrapped in ``jit`` before being returned.
    """
    batched = vmap(func, in_axes=(None, 0, None), out_axes=0)
    return jit(batched)
# Shapes for a single (unbatched) example: W @ X[0] and W @ X[0] + b.
print("dot product shape")
print(jnp.dot(W, X[0]).shape)
print((jnp.dot(W, X[0]) + b).shape)

# Batch the single-example layer over the rows of X. Deliberately NOT named
# `relu`: that name is later bound to the activation function used by
# `predict`, and re-running this cell would otherwise clobber it.
batched_relu_layer = vmap_ReLU_Layer(ReLU_Layer)
result = batched_relu_layer(W, X, b)
print(result.shape)

In [None]:
## test: three ways to apply a ReLU layer to a batch
def relu_layer(params, x):
    """Simple ReLU layer for a single sample: ReLU(params[0] @ x + params[1]).

    ``params`` is a ``[W, b]`` pair and ``x`` one input vector. ``jnp.dot`` is
    used (not ``np.dot``) so the function can be traced by ``jit``/``vmap``;
    NumPy functions cannot operate on JAX tracers.
    """
    weights, bias = params[0], params[1]
    return ReLU(jnp.dot(weights, x) + bias)

def batch_version_relu_layer(params, x):
    """Error-prone hand-batched version of ``relu_layer``.

    Assumes ``x`` holds one sample per row, so the layer is written as
    ``x @ W.T + b`` — the transpose and bias broadcasting have to be worked
    out by hand, which is what ``vmap`` avoids. Uses the ``x`` argument (not
    a global) and ``jnp.dot`` so it also works under ``jit``.
    """
    weights, bias = params[0], params[1]
    return ReLU(jnp.dot(x, weights.T) + bias)

def vmap_relu_layer(params, x):
    """vmap version of the ReLU layer: apply ``relu_layer`` to every row of x.

    ``params`` is shared across the batch (in_axes None) and ``x`` is mapped
    over axis 0; returns the per-row outputs stacked along axis 0.
    """
    batched_layer = jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))
    # NOTE(review): the jitted wrapper is rebuilt on every call, which likely
    # forces a re-trace each time; hoist it if this is called in a loop.
    return batched_layer(params, x)

# Three ways to push the batch X through the single [W, b] layer from the
# first cell. `params_old` is used because `params` is not defined until a
# later cell, where it holds the MLP's list of (w, b) tuples instead.
out = jnp.stack([relu_layer(params_old, X[i, :]) for i in range(X.shape[0])])
out = batch_version_relu_layer(params_old, X)
out = vmap_relu_layer(params_old, X)

### JAX documentation example: [training a simple neural network with TFDS data](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html)

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Hyperparameters

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Randomly initialise one dense layer mapping m inputs to n outputs.

    ``key`` is split so weights and biases come from independent subkeys.
    Returns ``(weights, biases)`` with shapes ``(n, m)`` and ``(n,)``: standard
    normal draws multiplied by ``scale``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

def init_network_params(sizes, key):
    """Initialise every layer of a fully-connected network with widths ``sizes``.

    Returns one ``(weights, biases)`` tuple per consecutive pair of widths,
    i.e. ``len(sizes) - 1`` layers. ``key`` is split into ``len(sizes)``
    subkeys; ``zip`` stops at the shorter sequence, so the last subkey is
    left unused.
    """
    layer_keys = random.split(key, len(sizes))
    layer_shapes = zip(sizes[:-1], sizes[1:])
    return [random_layer_params(n_in, n_out, layer_key)
            for (n_in, n_out), layer_key in zip(layer_shapes, layer_keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
# A single root PRNG key (despite the plural name); init_network_params
# splits it into per-layer subkeys.
keys = random.PRNGKey(0)
params = init_network_params(layer_sizes, keys)
print(layer_sizes[:-1])
print(layer_sizes[1:])
print(keys)

# Show the (fan_in, fan_out, subkey) triples that init_network_params zips
# over. Iterate the split subkeys, as that function does; iterating the raw
# root key would walk over the key array's own elements rather than giving
# one subkey per layer.
layer_keys = random.split(keys, len(layer_sizes))
for m, n, k in zip(layer_sizes[:-1], layer_sizes[1:], layer_keys):
    print(m)
    print(n)
    print(k)

## Auto-batching predictions

In [None]:
import numpy as np

def relu(x):
    """Return the elementwise rectified linear unit max(0, x)."""
    rectified = jnp.maximum(0, x)
    return rectified

def predict(params, img):
    """Forward pass of the MLP for a single, unbatched example.

    ``params`` is a list of ``(w, b)`` layer tuples. Every layer except the
    last computes ``relu(w @ activations + b)``; the last is affine only and
    its logits are passed through softmax. Returns class probabilities.
    """
    hidden_layers, (out_w, out_b) = params[:-1], params[-1]
    activations = img
    for w, b in hidden_layers:
        activations = relu(jnp.dot(w, activations) + b)
    logits = jnp.dot(out_w, activations) + out_b
    return jax.nn.softmax(logits)

In [None]:
# This works on single examples: one flattened 28x28 image with no batch axis.
random_flattened_img = random.normal(random.PRNGKey(1), (28 * 28, ))
preds = predict(params, random_flattened_img)
print(preds)
# Softmax output, so the probabilities should sum to (approximately) one.
print(preds.sum())

In [None]:
# This works on single examples; here only the output shape is inspected.
single_image_shape = (28 * 28,)
random_flattened_image = random.normal(random.PRNGKey(1), single_image_shape)
preds = predict(params, random_flattened_image)
print(preds.shape)

### `predict` doesn't work with batched inputs

In [None]:
# Doesn't work with a batch: jnp.dot contracts the first weight matrix's last
# axis (784) with the input's first axis (10 here), so the shapes do not
# line up and JAX raises a TypeError.
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')

In [None]:
# Map `predict` over axis 0 of the images while sharing the same params
# (in_axes=(None, 0)); jit then compiles the batched function.
batched_predict = jax.jit(jax.vmap(predict, in_axes=(None, 0), out_axes=0))

batched_inputs = random.normal(random.PRNGKey(1), (12, 28 * 28))
# Shape experiments: a full slice keeps the shape; indexing with None
# inserts a new length-1 axis.
print(batched_inputs.shape)
print(batched_inputs[:].shape)
print(batched_inputs[:, None].shape)
batched_pred = batched_predict(params, batched_inputs)
print(np.array([1, 2]).shape)
print(jnp.arange(2).shape)
# One softmax vector per input row.
print(batched_pred.shape)

## Utility and loss functions

## Data Loading with `tensorflow/datasets`

In [None]:
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of the integer labels ``x`` with ``k`` classes.

    ``x`` is indexed as ``x[:, None]``, so it is expected to be 1-D; the
    result has shape ``(len(x), k)`` and the requested ``dtype``.
    """
    class_ids = jnp.arange(k)
    matches = x[:, None] == class_ids
    return jnp.array(matches, dtype)

def accuracy(params, images, targets):
    """Fraction of ``images`` whose predicted class equals the one-hot target.

    Both the predictions and ``targets`` are reduced to class ids with
    ``argmax`` along axis 1.
    """
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets, eps=1e-7):
    """Mean categorical cross-entropy between predictions and one-hot targets.

    ``batched_predict`` yields softmax *probabilities* (``predict`` ends in
    ``jax.nn.softmax``), so they must be logged before being weighted by the
    one-hot ``targets``. Multiplying the raw probabilities would minimise
    ``-p(correct class)`` rather than the negative log-likelihood.
    Probabilities are floored at ``eps`` so a saturated softmax cannot give
    ``log(0) = -inf`` and NaN gradients.

    The mean runs over both the batch and class axes, so the value is the
    per-example cross-entropy divided by the number of classes.
    """
    preds = batched_predict(params, images)
    log_probs = jnp.log(jnp.maximum(preds, eps))
    return -jnp.mean(log_probs * targets)

# my own loss function
def cross_entropy_loss(output, target):
    """Negative log-probability that ``output`` assigns to class ``target``.

    ``output`` is indexed directly with ``target``, so it is one example's
    probability vector and ``target`` an integer class index.
    """
    target_prob = output[target]
    return -jnp.log(target_prob)

def batched_cross_entropy_loss(cross_entropy_loss=cross_entropy_loss):
    """Return a jitted function computing per-example losses for a batch.

    The returned callable takes ``(outputs, targets)``, maps
    ``cross_entropy_loss`` over axis 0 of both, and returns the vector of
    per-example losses (no reduction is applied).
    """
    per_example = vmap(cross_entropy_loss, in_axes=(0, 0), out_axes=0)
    return jit(per_example)

# Sanity-check the batched loss on two hand-made probability vectors; the
# expected output is [-log(0.1), -log(0.5)].
test_in = jnp.array([[0.1, 0.9], [0.5, 0.5]])
test_target = jnp.array([0, 1])
my_loss = batched_cross_entropy_loss()
test_out = my_loss(test_in, test_target)
print(test_out)

def myloss_func(params, images, targets):
    """Per-example loss of the network's predictions via ``my_loss``.

    Prints the prediction and target shapes as a debugging aid (under ``jit``
    these prints would only run at trace time). ``targets`` is passed
    straight to ``my_loss``; with the batched cross-entropy used in this
    notebook it should hold integer class indices, not one-hot rows. The
    result is a vector, so reduce it (e.g. ``jnp.mean``) before ``grad``.
    """
    preds = batched_predict(params, images)
    for arr in (preds, targets):
        print(arr.shape)
    return my_loss(preds, targets)

@jit
def update(params, x, y):
    """One plain gradient-descent step on ``loss``.

    Returns a new list of ``(w, b)`` tuples, each parameter moved by
    ``-step_size`` times its gradient. ``step_size`` is a module global read
    while tracing, so ``jit`` bakes its current value into the compiled step.
    """
    grads = grad(loss)(params, x, y)
    new_params = []
    for (w, b), (dw, db) in zip(params, grads):
        new_params.append((w - step_size * dw, b - step_size * db))
    return new_params

In [None]:
import tensorflow as tf
# Hide GPUs from TensorFlow so it does not grab all of the GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch the full datasets for evaluation. With batch_size=-1, tfds.load
# returns each split as tf.Tensors instead of a tf.data.Dataset;
# tfds.as_numpy converts them to NumPy arrays.
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c


def _flatten_split(split):
    """Return (images, labels) for one split.

    Images are reshaped to (num_examples, num_pixels); labels are one-hot
    encoded with num_labels classes.
    """
    raw_images = split['image']
    images = jnp.reshape(raw_images, (len(raw_images), num_pixels))
    labels = one_hot(split['label'], num_labels)
    return images, labels


# Full train set, then full test set
train_images, train_labels = _flatten_split(train_data)
test_images, test_labels = _flatten_split(test_data)

In [None]:
# Report (images, labels) shapes for each split.
for _split_name, _images, _labels in (('Train:', train_images, train_labels),
                                      ('Test:', test_images, test_labels)):
    print(_split_name, _images.shape, _labels.shape)

## Training Loop

In [None]:
import time

def get_train_batches():
    """Return an iterable of NumPy ``(images, labels)`` MNIST training batches.

    Batches hold ``batch_size`` examples and are prefetched one ahead. The
    pipeline applies no shuffling of its own.
    """
    # as_supervised=True yields (image, label) tuples instead of dicts.
    train_ds = tfds.load(name='mnist', split='train', as_supervised=True,
                         data_dir=data_dir)
    # Arbitrary tf.data transformations can be chained here.
    train_ds = train_ds.batch(batch_size).prefetch(1)
    # Convert the tf.data.Dataset into an iterable of NumPy arrays.
    return tfds.as_numpy(train_ds)

# Train for num_epochs passes over the training set, reporting wall time and
# train/test accuracy after each epoch.
for epoch in range(num_epochs):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    for x, y in get_train_batches():
        x = jnp.reshape(x, (len(x), num_pixels))
        y = one_hot(y, num_labels)
        params = update(params, x, y)
    # JAX dispatches work asynchronously: `update` can return before its
    # computation has finished. Wait for the final parameters so the epoch
    # time covers all of the queued training work.
    for leaf in jax.tree_util.tree_leaves(params):
        leaf.block_until_ready()
    epoch_time = time.perf_counter() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))