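# Training a simple neural network on MNIST with JAX and TensorFlow Datasets

Adapted from the JAX tutorial at
https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html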

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
    """Draw scaled Gaussian initial parameters for one dense layer.

    Args:
        m: number of inputs to the layer.
        n: number of outputs of the layer.
        key: JAX PRNG key; it is split once so weights and biases use
            independent keys.
        scale: factor applied to the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

In [3]:
def init_network_params(sizes, key):
    """Initialize parameters for every layer of a fully-connected network.

    Args:
        sizes: sequence of layer widths; consecutive pairs define one layer.
        key: JAX PRNG key from which the per-layer keys are derived.

    Returns:
        A list with one ``(weights, biases)`` tuple per consecutive pair of
        entries in ``sizes``.
    """
    # One key is generated per entry of `sizes`; zip stops after
    # len(sizes) - 1 layers, so the final key is never consumed.
    layer_keys = random.split(key, len(sizes))
    network = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        network.append(random_layer_params(fan_in, fan_out, layer_key))
    return network

In [4]:
# Optimisation hyperparameters.
step_size = 0.01
num_epochs = 10
batch_size = 128

# Architecture: 784 inputs, two hidden layers of 512 units, one output per class.
n_targets = 10
layer_sizes = [784, 512, 512, n_targets]

# A fixed PRNG seed keeps the initial parameters reproducible.
params = init_network_params(layer_sizes, key=random.PRNGKey(0))

I0000 00:00:1701557099.478758   18039 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
from jax.scipy.special import logsumexp

In [6]:
def relu(x):
    """Rectified linear unit: element-wise maximum of 0 and ``x``."""
    rectified = jnp.maximum(0, x)
    return rectified

def predict(params, image):
    """Run the network forward on a single example.

    Args:
        params: list of ``(weights, biases)`` tuples, one per layer.
        image: one flattened input example. The function is not batched:
            each layer computes ``jnp.dot(w, activations)``, so a batch must
            be mapped over with ``vmap``.

    Returns:
        The final-layer logits minus their ``logsumexp``, i.e. log-softmax
        values over the outputs.
    """
    hidden_layers = params[:-1]
    out_w, out_b = params[-1]

    hidden = image
    for weight, bias in hidden_layers:
        hidden = relu(jnp.dot(weight, hidden) + bias)

    logits = jnp.dot(out_w, hidden) + out_b
    # Subtracting logsumexp normalises the logits in log space.
    return logits - logsumexp(logits)
    

In [7]:
# `predict` handles one example at a time, so passing a whole batch is
# expected to fail with a shape error. Catch and display that error instead of
# letting it abort a Restart & Run All.
random_image = random.normal(random.PRNGKey(1), (10, 28*28,))
try:
    preds = predict(params, random_image)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")
else:
    print(preds.shape)

TypeError: dot_general requires contracting dimensions to have the same shape, got (784,) and (10,).

In [8]:
# Vectorise `predict` over axis 0 of the image argument; in_axes=None for the
# first argument means the same params are used for every example.
batched_predict = vmap(predict, in_axes=(None, 0))

random_batch = random.normal(random.PRNGKey(1), (10, 28*28))
batched_predict(params, random_batch).shape

(10, 10)

In [9]:
def one_hot(x, k, dtype=jnp.float32):
    """Build a one-hot encoding of the labels in ``x`` with ``k`` columns.

    Args:
        x: array of labels; it is indexed as ``x[:, None]`` and compared
            against ``0 .. k-1``.
        k: number of classes (size of the new trailing axis).
        dtype: dtype of the returned array.

    Returns:
        Array of ``dtype`` that is 1 where the label equals the column index
        and 0 elsewhere.
    """
    class_ids = jnp.arange(k)
    matches = x[:, None] == class_ids
    return jnp.array(matches, dtype)

def accuracy(params, images, targets):
    """Fraction of examples whose predicted class matches the target class.

    Args:
        params: network parameters passed through to ``batched_predict``.
        images: batch of inputs, one example per row.
        targets: per-example target scores; the class is taken as the argmax
            along axis 1 (e.g. one-hot rows).

    Returns:
        Mean of the element-wise "prediction equals target" comparison.
    """
    scores = batched_predict(params, images)
    hits = jnp.argmax(scores, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(hits)

def loss(params, images, targets):
    """Negative mean of the predictions weighted by the targets.

    The mean runs over every element of ``preds * targets`` (all examples
    and all classes), not only over the batch axis.
    """
    log_probs = batched_predict(params, images)
    weighted = log_probs * targets
    return -jnp.mean(weighted)

In [10]:
@jit
def update(params, x, y):
    """Apply one gradient-descent step and return the new parameters.

    Args:
        params: list of ``(weights, biases)`` tuples.
        x: batch of inputs forwarded to ``loss``.
        y: batch of targets forwarded to ``loss``.

    Returns:
        A new list of ``(weights, biases)`` tuples, each moved against its
        gradient by the module-level ``step_size``.
    """
    grads = grad(loss)(params, x, y)
    stepped = []
    for (w, b), (dw, db) in zip(params, grads):
        stepped.append((w - step_size * dw, b - step_size * db))
    return stepped

In [11]:
import tensorflow as tf
# Hide every GPU from TensorFlow (empty visible-device list) so that it does
# not grab all of the GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch the full datasets for evaluation.
# With batch_size=-1, tfds.load returns tf.Tensors (tf.data.Datasets otherwise);
# tfds.as_numpy converts them to NumPy arrays (or iterables of NumPy arrays).
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data = mnist_data['train']
test_data = mnist_data['test']

# Dataset metadata: number of classes and number of values per image.
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set: one flat vector per image, one-hot labels.
train_images = jnp.reshape(train_data['image'], (len(train_data['image']), num_pixels))
train_labels = one_hot(train_data['label'], num_labels)

# Full test set: same treatment as the train set.
test_images = jnp.reshape(test_data['image'], (len(test_data['image']), num_pixels))
test_labels = one_hot(test_data['label'], num_labels)

  from .autonotebook import tqdm as notebook_tqdm


Downloading and preparing dataset mnist (11.06 MiB) to /tmp/tfds/mnist/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s][A

Dl Completed...:   0%|                                                                          | 0/1 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s][A

Dl Completed...:   0%|                                                                          | 0/2 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s][A

Dl Completed...:   0%|                                                                          | 0/3 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s][A

Dl Completed...:   0%|                                                                          | 0/4 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s][A

Dl Completed...:   0%|                                                                          | 0/4 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s][A

Dl Completed...:   0%|                                                                          | 0/4 [00:00<?, ? url/s]







Shuffling...:   0%|                                                                          | 0/10 [00:00<?, ? shard/s]

Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`


Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Reading...: 0 examples [00:00, ? examples/s][A
                                            [A
Writing...:   0%|                                                                       | 0/6000 [00:00<?, ? examples/s][A
                                                                                                                        [A
Reading...: 0 examples [00:00, ? examples/s][A
                                            [A
Writing...:   0%|                                                                       | 0/6000 [00:00<?, ? examples/s][A
                                                                                                                        [A
Reading...: 0 examples [00:00, ? examples/s][A
                                            [A
Writing...:   0%|                                                                       | 0/6000 [00:00<?, ? examples/s][A
      

Dataset mnist downloaded and prepared to /tmp/tfds/mnist/1.0.0. Subsequent calls will reuse this data.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


In [12]:
# Report the image and label array shapes of each split.
for split_name, split_images, split_labels in (
    ('Train:', train_images, train_labels),
    ('Test:', test_images, test_labels),
):
    print(split_name, split_images.shape, split_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


In [13]:
import time

def get_train_batches():
  """Return an iterable of ``(images, labels)`` NumPy batches of MNIST train.

  The dataset is loaded from ``data_dir`` on every call, batched with the
  module-level ``batch_size`` and prefetched one batch ahead.
  """
  # as_supervised=True yields (image, label) tuples rather than dicts.
  train_ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # Any tf.data input pipeline could be assembled at this point.
  batched_ds = train_ds.batch(batch_size)
  prefetched_ds = batched_ds.prefetch(1)
  # tfds.as_numpy turns the tf.data.Dataset into an iterable of NumPy arrays.
  return tfds.as_numpy(prefetched_ds)

# Train for `num_epochs` passes over the training set, reporting the duration
# of each pass and the accuracy on the full train and test sets.
for epoch in range(num_epochs):
  # perf_counter is a monotonic, high-resolution clock meant for measuring
  # elapsed time; time.time() can jump if the system clock is adjusted.
  start_time = time.perf_counter()
  for x, y in get_train_batches():
    # Flatten each image to a vector and one-hot encode the integer labels.
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  # JAX dispatches work asynchronously, so wait for the last update to finish
  # before reading the clock; otherwise pending computation is not counted.
  params[-1][-1].block_until_ready()
  epoch_time = time.perf_counter() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy {train_acc}")
  print(f"Test set accuracy {test_acc}")



Epoch 0 in 3.04 sec
Training set accuracy 0.9192333221435547
Test set accuracy 0.9192999601364136




Epoch 1 in 2.83 sec
Training set accuracy 0.9392333626747131
Test set accuracy 0.9382999539375305




Epoch 2 in 2.80 sec
Training set accuracy 0.9499500393867493
Test set accuracy 0.9484999775886536




Epoch 3 in 3.26 sec
Training set accuracy 0.9591000080108643
Test set accuracy 0.9563999772071838




Epoch 4 in 2.59 sec
Training set accuracy 0.9643666744232178
Test set accuracy 0.9608999490737915




Epoch 5 in 2.93 sec
Training set accuracy 0.9688000082969666
Test set accuracy 0.9637999534606934




Epoch 6 in 2.65 sec
Training set accuracy 0.9726499915122986
Test set accuracy 0.9664999842643738




Epoch 7 in 3.05 sec
Training set accuracy 0.9748333692550659
Test set accuracy 0.9663999676704407




Epoch 8 in 2.74 sec
Training set accuracy 0.9777333736419678
Test set accuracy 0.9695000052452087
Epoch 9 in 3.08 sec
Training set accuracy 0.9798499941825867
Test set accuracy 0.97079998254776
