In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Suppress warning and info messages from jax
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [15]:
# Random number generation

# One explicit PRNG key per matrix, seeded with 0 and 1
key_x, key_y = random.PRNGKey(0), random.PRNGKey(1)

# Draw two n x n float32 matrices of normally distributed samples
n = 3
X = random.normal(key_x, shape=(n, n), dtype=jnp.float32)
Y = random.normal(key_y, shape=(n, n), dtype=jnp.float32)

print(X, "\n")
print(Y)

[[-0.3721109   0.26423115 -0.18252768]
 [-0.7368197   0.44973662 -0.1521442 ]
 [-0.67135346 -0.5908641   0.73168886]] 

[[ 0.690805   -0.48744103 -1.155789  ]
 [ 0.12108463  1.2010182  -0.5078766 ]
 [ 0.91568655  1.70968    -0.36749417]]


In [3]:
# Re-draw X and Y as 3000 x 3000 matrices from the keys key_x and key_y
n = 3000
X = random.normal(key_x, (n, n), dtype=jnp.float32)
Y = random.normal(key_y, (n, n), dtype=jnp.float32)

# Time the multiplication of the two matrices.
# JAX uses asynchronous execution by default; block_until_ready() forces the call
# to wait for the result, so the timing covers the computation itself.
%timeit jnp.dot(X, Y).block_until_ready()

233 ms ± 3.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Using jit() to speed up functions
key = random.PRNGKey(42)

def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Entries greater than zero pass through unchanged; all other entries are
  mapped to alpha * exp(x) - alpha. The result is multiplied by lmbda.
  """
  saturating_part = alpha * jnp.exp(x) - alpha
  selected = jnp.where(x > 0, x, saturating_part)
  return lmbda * selected

# By default JAX tries to run on a GPU or TPU.
# If none is available it falls back to the CPU.
# jit() uses XLA (Accelerated Linear Algebra) for increased performance.
x = random.normal(key, (1000000,))
# Baseline: time the function as written, without jit
%timeit selu(x).block_until_ready()

# Just-in-time (JIT) compilation: sequences of operations can be optimized together and run at once.
# Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

2.09 ms ± 36.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
350 µs ± 58.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [5]:
# Taking derivatives (gradients)

# Function to differentiate
def logistic_func(x):
  """Return the sum over all elements of the logistic sigmoid 1 / (1 + exp(-x))."""
  sigmoid = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(sigmoid)

# Generate inputs: the vector [0., 1., 2.]
x_small = jnp.arange(3.)
print(x_small)

# grad() builds a new function that evaluates the gradient of logistic_func
# with respect to its argument
derivative_fn = grad(logistic_func)

# Evaluate the gradient at x_small (one partial derivative per input element)
print("\nGradient:\n",derivative_fn(x_small))

[0. 1. 2.]

Gradient:
 [0.25       0.19661197 0.10499357]


In [6]:
# Auto-vectorization with vmap: turn a function written for a single example
# into one that is mapped over a batch axis

# Derive separate subkeys so `mat` and `batched_x` are not drawn from the same
# PRNG key (a JAX key should not be reused for independent draws)
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
  """Multiply the module-level matrix `mat` by a single vector `v`."""
  product = jnp.dot(mat, v)
  return product

@jit
def vmap_batched_apply_matrix(v_batched):
  """Run `apply_matrix` over `v_batched` through vmap; the wrapper is jit-compiled."""
  apply_to_batch = vmap(apply_matrix)
  return apply_to_batch(v_batched)

# Time the jitted, vmapped function; block_until_ready() waits for the result
print('\nAuto-vectorized with vmap:')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()


Auto-vectorized with vmap:
22.7 µs ± 30.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
# Arrays
import numpy as np

# NumPy arrays are mutable, so in-place assignment (x[0] = 1) is allowed
x = np.arange(10)

# JAX arrays are immutable, so x[0] = 1 is not allowed
x = jnp.arange(10)

# Instead, .at[...].set(...) produces an updated copy and leaves x unchanged
y = x.at[0].set(10)

print(x, y, sep="\n")

[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]


In [19]:
# MLP Model

# Helper that draws random weights and biases for one dense layer
def random_layer_params(m, n, key, scale=1e-2):
    """Return a `(weights, biases)` pair for a dense layer with m inputs and n outputs.

    `key` is split once so the weights, shape (n, m), and the biases, shape (n,),
    come from different subkeys. Both are normal samples multiplied by `scale`.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

# Initialize every layer of a fully-connected network whose layer widths are "sizes"
def init_network_params(sizes, key):
    """Return a list with one `random_layer_params` result per consecutive pair in `sizes`.

    `key` is split into len(sizes) subkeys; zip() stops at the shortest input,
    so only the first len(sizes) - 1 of them are consumed.
    """
    layer_keys = random.split(key, len(sizes))
    layer_specs = zip(sizes[:-1], sizes[1:], layer_keys)
    return [random_layer_params(fan_in, fan_out, layer_key)
            for fan_in, fan_out, layer_key in layer_specs]

# Hyperparameters
# Layer widths: 784 inputs (28 * 28 pixels), two hidden layers of 512 units, 10 outputs
layer_sizes = [784, 512, 512, 10]
step_size = 0.01  # multiplier applied to the gradients in update()
num_epochs = 10  # passes over the training data in train()
batch_size = 128  # examples per batch in get_train_batches()
n_targets = 10  # NOTE(review): not referenced anywhere else in the visible cells
# Initial (untrained) parameters, seeded with PRNGKey(0)
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [28]:
# Define relu and prediction function

from jax.scipy.special import logsumexp

def relu(x):
  """Rectified linear unit: elementwise maximum of 0 and x."""
  zero = 0
  return jnp.maximum(zero, x)

def predict(params, image):
  """Compute the network output for one flattened image.

  Every layer except the last applies an affine map followed by relu. The last
  layer produces logits, which are normalized by subtracting their logsumexp.
  """
  hidden_layers, output_layer = params[:-1], params[-1]

  activations = image
  for weights, bias in hidden_layers:
    activations = relu(jnp.dot(weights, activations) + bias)

  final_w, final_b = output_layer
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

# Running on a single example: a random vector standing in for a flattened 28 x 28 image
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
print(f"Image Shape: {random_flattened_image.shape}")

# predict() handles one example at a time; the result has one entry per output unit
preds = predict(params, random_flattened_image)
print(preds.shape)

Image Shape: (784,)
(10,)


In [21]:
# Define batch of images - 10 random vectors, each standing in for a flattened 28x28 image
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))

# Make a batched version of the `predict` function:
# in_axes=(None, 0) passes `params` unchanged to every call and maps over axis 0 of the images
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [22]:
# Define utility functions

# One-hot encode integer class labels
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  class_ids = jnp.arange(k)
  # Comparing a column of labels against the row of class ids broadcasts to a (len(x), k) mask
  is_match = x[:, None] == class_ids
  return jnp.array(is_match, dtype)

# Fraction of examples classified correctly
def accuracy(params, images, targets):
  """Return the mean agreement between predicted and target classes.

  The predicted class is the argmax of `batched_predict` along axis 1, and the
  target class is the argmax of `targets` along axis 1.
  """
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  target_class = jnp.argmax(targets, axis=1)
  is_correct = predicted_class == target_class
  return jnp.mean(is_correct)

# Training loss
def loss(params, images, targets):
  """Return the negated mean of the `batched_predict` outputs multiplied elementwise by `targets`."""
  log_probs = batched_predict(params, images)
  weighted = log_probs * targets
  return -jnp.mean(weighted)

# One gradient-descent step (gradients come from differentiating `loss`)
@jit
def update(params, x, y):
  """Return new parameters after one gradient step on the batch (x, y).

  Each weight and bias is moved against its gradient of `loss`, scaled by the
  module-level `step_size`. The input `params` list is not modified.
  """
  grads = grad(loss)(params, x, y)
  new_params = []
  for (w, b), (dw, db) in zip(params, grads):
    new_params.append((w - step_size * dw, b - step_size * db))
  return new_params

In [23]:
# Load and Create Datasets
# Set before TensorFlow is imported; presumably this hides CUDA devices from
# TensorFlow so it stays on the CPU (the listing below reports 0 GPUs) — TODO confirm intent
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
import tensorflow_datasets as tfds

# Echo the environment settings that affect device visibility and log verbosity
print(f"CUDA_VISIBLE_DEVICES = {os.environ['CUDA_VISIBLE_DEVICES']}")
print(f"TF_CPP_MIN_LOG_LEVEL = {os.environ['TF_CPP_MIN_LOG_LEVEL']}")

# List the GPUs and CPUs available to TensorFlow
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
print("Num CPUs Available: ", len(tf.config.list_physical_devices('CPU')))

# NOTE(review): device_lib lives under tensorflow.python, an internal namespace;
# consider relying on tf.config.list_physical_devices() alone
from tensorflow.python.client import device_lib
print("\n", device_lib.list_local_devices(), '\n')

# Directory passed to tfds.load as data_dir
data_dir = './datasets'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)

mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
# Number of classes and flattened image size, taken from the dataset metadata
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set: flatten each image to num_pixels values and one-hot encode the labels
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set: same preprocessing as the train set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

CUDA_VISIBLE_DEVICES = -1
TF_CPP_MIN_LOG_LEVEL = 2
Num GPUs Available:  0
Num CPUs Available:  1

 [name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 3256589239373377352
xla_global_id: -1
] 

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


In [26]:
import time

def get_train_batches():
    """Return an iterable of NumPy (images, labels) batches over the MNIST training split.

    Reads the module-level `data_dir` and `batch_size`.
    """
    # as_supervised=True yields (image, label) tuples instead of dicts
    train_ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)

    # Group examples into batches and prefetch one batch ahead
    batched_ds = train_ds.batch(batch_size).prefetch(1)

    # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
    return tfds.as_numpy(batched_ds)

def train(params):
    """Train the network for `num_epochs` epochs and return the final parameters.

    After each epoch, prints the epoch duration and the accuracy on the full
    training and test sets.

    Args:
        params: list of (weights, biases) pairs, one per layer.

    Returns:
        The parameters after the last update. `update` builds a new parameter
        list rather than modifying its input, so the caller must keep this
        return value to use the trained model.
    """
    for epoch in range(num_epochs):

        start_time = time.time()
        for x, y in get_train_batches():
            # Flatten each image and one-hot encode the labels, as for the full datasets
            x = jnp.reshape(x, (len(x), num_pixels))
            y = one_hot(y, num_labels)
            params = update(params, x, y)

        epoch_time = time.time() - start_time

        train_acc = accuracy(params, train_images, train_labels)
        test_acc = accuracy(params, test_images, test_labels)
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))

    return params

# Train Model
# Bind the result to a new name: the trained parameters are kept, and the
# untrained `params` stay intact so re-running this cell starts from scratch
trained_params = train(params)

Epoch 0 in 2.91 sec
Training set accuracy 0.9253333210945129
Test set accuracy 0.9268999695777893
Epoch 1 in 2.60 sec
Training set accuracy 0.9428499937057495
Test set accuracy 0.9413999915122986
Epoch 2 in 2.56 sec
Training set accuracy 0.9532666802406311
Test set accuracy 0.9513999819755554
Epoch 3 in 2.63 sec
Training set accuracy 0.9600499868392944
Test set accuracy 0.955299973487854
Epoch 4 in 2.71 sec
Training set accuracy 0.965149998664856
Test set accuracy 0.9602999687194824
Epoch 5 in 2.66 sec
Training set accuracy 0.9691833257675171
Test set accuracy 0.9629999995231628
Epoch 6 in 2.87 sec
Training set accuracy 0.9724833369255066
Test set accuracy 0.965499997138977
Epoch 7 in 2.59 sec
Training set accuracy 0.9754166603088379
Test set accuracy 0.9664999842643738
Epoch 8 in 2.50 sec
Training set accuracy 0.9781000018119812
Test set accuracy 0.9680999517440796
Epoch 9 in 2.60 sec
Training set accuracy 0.9803500175476074
Test set accuracy 0.9691999554634094


In [25]:
# Load MNIST Using Keras
import keras

# load_data() returns (images, labels) pairs for the train and test splits
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Images are 28 x 28 arrays here (not flattened); labels are 1-D
print('Train:', x_train.shape, y_train.shape)
print('Test:', x_test.shape, y_test.shape)

Train: (60000, 28, 28) (60000,)
Test: (10000, 28, 28) (10000,)
