In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax.scipy as jsp

In [None]:
# Report which backend JAX is running on ('cpu', 'gpu', 'tpu', ...).
# `jax.default_backend()` is the public API; `jax.lib.xla_bridge` is deprecated.
import jax
print(jax.default_backend())

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Randomly initialise one dense layer mapping ``m`` inputs to ``n`` outputs.

    Args:
        m: number of input features.
        n: number of output features.
        key: JAX PRNG key; split into one subkey for the weights and one for
            the bias.
        scale: multiplier applied to the standard-normal samples.

    Returns:
        ``(w, b)`` with ``w`` of shape ``(n, m)`` and ``b`` of shape ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, (n, m))
    bias = random.normal(bias_key, (n,))
    return scale * weights, scale * bias

def init_network_params(sizes, key):
    """Build the parameter list of a fully connected network.

    Args:
        sizes: layer widths, input first; each consecutive pair
            ``(sizes[i], sizes[i + 1])`` defines one layer.
        key: JAX PRNG key, split into ``len(sizes)`` subkeys (the last one is
            not used).

    Returns:
        A list of ``(w, b)`` tuples from ``random_layer_params``, one per layer.
    """
    layer_keys = random.split(key, len(sizes))
    fan_pairs = zip(sizes[:-1], sizes[1:])
    return [random_layer_params(fan_in, fan_out, layer_keys[i])
            for i, (fan_in, fan_out) in enumerate(fan_pairs)]


In [None]:
# Network and training hyper-parameters.
layer_sizes = [784, 512, 512, 47]  # 28*28 flattened pixels -> 2 hidden layers -> class logits
step_size = 0.01
num_epochs = 10
batch_size = 128
# Number of output classes. Derived from the output layer so the one-hot width
# used for targets can never drift from the width of the logits.
n_targets = layer_sizes[-1]
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [None]:
from jax.scipy.special import logsumexp

def relu(x):
    """Rectified linear unit, applied element-wise: ``max(0, x)``."""
    rectified = jnp.maximum(0, x)
    return rectified

def predict(params, image):
    """Log-probabilities of a single flattened image under the MLP.

    Args:
        params: list of ``(w, b)`` pairs (as built by ``init_network_params``);
            every layer except the last is followed by ``relu``.
        image: one flattened input vector without a batch axis
            (``batched_predict`` handles batches).

    Returns:
        The log-softmax of the final layer's logits, i.e.
        ``logits - logsumexp(logits)``.
    """
    hidden = image
    for layer_w, layer_b in params[:-1]:
        hidden = relu(jnp.dot(layer_w, hidden) + layer_b)

    out_w, out_b = params[-1]
    logits = jnp.dot(out_w, hidden) + out_b
    # Subtracting logsumexp normalises the logits into log-probabilities.
    return logits - logsumexp(logits)

In [None]:
# `predict` works on a single, unbatched example: run it on one random
# flattened 28x28 input and show the per-class log-probabilities.
random_flattened_image = random.normal(key=random.PRNGKey(1), shape=(28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
print(preds)

In [None]:
# Batched prediction: vmap `predict` over the leading axis of the images
# (in_axes=0) while every example shares the same `params` (in_axes=None).
batched_predict = vmap(predict, (None, 0))

In [None]:
# Same sanity check for the batched path: 10 random flattened images should
# come back as one row of log-probabilities per image.
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

In [None]:
# Utility functions

def one_hot(x, k, dtype=jnp.float32):
    """One-hot encode integer class labels.

    Args:
        x: array of integer labels, presumably 1-D (it is indexed as
            ``x[:, None]``).
        k: number of classes, i.e. the width of the encoding.
        dtype: dtype of the returned array.

    Returns:
        For 1-D ``x``, an array of shape ``(len(x), k)`` whose row ``i`` is 1
        in column ``x[i]`` and 0 elsewhere.
    """
    classes = jnp.arange(k)
    matches = x[:, None] == classes
    return jnp.array(matches, dtype)

def accuracy(params, images, targets):
    """Fraction of ``images`` whose predicted class equals the target class.

    The target class of each row is ``argmax(targets, axis=1)`` (so ``targets``
    is expected to be one-hot rows); the predicted class is the argmax of
    ``batched_predict``'s log-probabilities.
    """
    log_probs = batched_predict(params, images)
    hits = jnp.argmax(log_probs, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(hits)

def loss(params, images, targets):
    """Negative mean of ``log_probs * targets`` over every element.

    ``jnp.mean`` averages over both the batch and the class axis, so with
    one-hot ``targets`` of shape ``(batch, k)`` this is the mean cross-entropy
    divided by ``k``. The training step size is tuned against that scaling.
    """
    log_probs = batched_predict(params, images)
    weighted = log_probs * targets
    return -jnp.mean(weighted)

@jit
def update(params, x, y):
    """Take one plain SGD step on ``loss`` and return the new parameter list.

    ``step_size`` is a module-level float read while ``jit`` traces this
    function, so changing it later has no effect until a re-trace happens.
    """
    grads = grad(loss)(params, x, y)
    new_params = []
    for (w, b), (dw, db) in zip(params, grads):
        new_params.append((w - step_size * dw, b - step_size * db))
    return new_params

In [None]:
# Smoke-test the loss on random images with one-hot targets shaped like the
# training batches, (batch, n_targets). Class ids are drawn uniformly from
# [0, n_targets) and then one-hot encoded, instead of using (10, 1) Gaussian
# noise that would silently broadcast across every class column.
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28*28))
random_labels = one_hot(random.randint(random.PRNGKey(0), (10,), 0, n_targets), n_targets)
result = loss(params, random_flattened_images, random_labels)
assert result.shape == () and jnp.isfinite(result)
result

In [None]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import EMNIST

def numpy_collate(batch):
    """Collate samples with PyTorch's default rules, then turn every leaf of
    the collated structure (the tensors) into a NumPy array for JAX."""
    collated = data.default_collate(batch)
    return tree_map(np.asarray, collated)

class NumpyLoader(data.DataLoader):
  """A ``torch.utils.data.DataLoader`` whose batches are NumPy arrays.

  Every batch is collated with ``numpy_collate`` (PyTorch's default collation
  followed by converting each leaf to ``np.ndarray``), so batches can be fed
  to JAX functions directly. The named keyword arguments are forwarded to
  ``DataLoader`` unchanged. Any extra keyword arguments (e.g. ``generator`` or
  ``persistent_workers``) are forwarded as well.
  """

  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None, **kwargs):
    # Zero-argument super(). `super(self.__class__, self)` recurses forever in
    # any subclass: there `self.__class__` is the subclass, so the lookup lands
    # back on this very __init__.
    super().__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn,
        **kwargs)

class FlattenAndCast:
  """Dataset transform: turn an image into a flat float32 NumPy vector."""

  def __call__(self, pic):
    as_float = np.array(pic, dtype=jnp.float32)
    return as_float.ravel()

In [None]:
import torch

# Define our dataset, using torch datasets: the EMNIST "balanced" split, each
# image flattened to a float32 vector by FlattenAndCast.
mnist_dataset = EMNIST('/tmp/emnist/', split="balanced", download=True, transform=FlattenAndCast())
# Reshuffle every epoch so SGD doesn't revisit the batches in a fixed order.
# The default sampler draws its seed from torch's global RNG, so seeding it
# makes the shuffle order reproducible.
torch.manual_seed(0)
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [None]:
# Get the full train dataset (for checking accuracy while training).
# `.data` / `.targets` are torchvision's current attribute names; the
# `train_data` / `train_labels` / `test_data` / `test_labels` aliases are
# deprecated and only emit a warning before returning these same attributes.
train_images = mnist_dataset.data.numpy().reshape(len(mnist_dataset.data), -1)
train_labels = one_hot(mnist_dataset.targets.numpy(), n_targets)

# Get full test dataset
mnist_dataset_test = EMNIST('/tmp/emnist/', split="balanced", download=True, train=False)
test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
test_labels = one_hot(mnist_dataset_test.targets.numpy(), n_targets)

In [None]:
# Pull one batch from the loader and check the shape of its flattened images.
(images, labels) = next(iter(training_generator))
print(images.shape)

In [None]:
import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    # The loader yields integer class ids; `update` / `loss` expect one-hot rows.
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  # Accuracy on the full train and test sets after every epoch.
  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time),
        "Training set accuracy {}".format(train_acc),
        "Test set accuracy {}".format(test_acc),
        sep="\n")

In [None]:
import matplotlib.pyplot as plt
# Show the first 20 images of a single training batch and print each label
# next to the model's predicted class. The batch is fetched once; the loop
# previously rebuilt the DataLoader iterator (and loaded a batch) every time.
n_show = 20
rows = 1
columns = n_show
fig = plt.figure(figsize=(10,7))

images, labels = next(iter(training_generator))
preds = batched_predict(params, images[:n_show])
for img_index in range(n_show):
    ax = fig.add_subplot(rows, columns, img_index + 1)
    # torchvision's EMNIST images are stored transposed; `.T` displays them upright.
    ax.imshow(images[img_index].reshape(28, 28).T, cmap='gray')
    ax.axis('off')
    print(f"label {labels[img_index]}, prediction {jnp.argmax(preds[img_index])}")

In [None]:
# Side by side: one training image and a Gaussian-smoothed copy of it.
img_index = 5
images, labels = next(iter(training_generator))
preds = predict(params, images[img_index])

# 7x7 kernel: outer product of the standard-normal pdf sampled at the integers
# -3..3 along each axis. 'valid' mode trims the border, so the result is smaller.
x = jnp.linspace(-3, 3, 7)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
smooth_img = jsp.signal.convolve(images[img_index].reshape(28,28), window, mode='valid')

fig = plt.figure(figsize=(10,7))
fig.add_subplot(1,2,1)
plt.imshow(images[img_index].reshape(28,28), cmap='gray')
fig.add_subplot(1,2,2)
plt.imshow(smooth_img, cmap='gray')

print(x.shape)
print(window.shape)
print(smooth_img.shape)
print(f"label {labels[img_index]}, prediction {jnp.argmax(preds)}")