In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from torchvision.transforms import ToTensor

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))
    
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    
    print(keys)
    print(sizes[:-1])
    print(sizes[1:])
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], 
                                                            sizes[1:], 
                                                            keys)]

layer_sizes = [784, 512, 512,10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))



[[2285895361 1501764800]
 [1518642379 4090693311]
 [ 433833334 4221794875]
 [ 839183663 3740430601]]
[784, 512, 512]
[512, 512, 10]


In [3]:
import jax

In [4]:
jax.numpy.tanh

<CompiledFunction of <function _one_to_one_unop.<locals>.<lambda> at 0x7fb2083b8cb0>>

In [5]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def tanh(x):
    return jnp.tanh(x)

def selu(x):
    return jax.nn.selu(x)


def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [6]:
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))

In [7]:
batched_predict = vmap(predict, in_axes = (None, 0))
batched_predict_alt = vmap(predict, in_axes = (0, 0))

In [8]:
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [115]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    
    return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [108]:
import numpy as np
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
#     print('called')
#     print(type(batch))
#     print(len(batch))
    #print(batch[0])
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class NumpyLoader(data.DataLoader):
    def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
        super(self.__class__, self).__init__(dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn)

# This is applied when the __getitem__ method in the dataset (mnist_dataset below)
# is invoked
class FlattenAndCast(object):
    def __call__(self, pic):
        #print(pic)
        return np.ravel(np.array(pic, dtype=jnp.float32)).astype(jnp.float32)

In [109]:
# Define our dataset, using torch datasets
mnist_dataset = MNIST('data/mnist/', 
                      download=True,
                      transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, 
                                 batch_size=batch_size,
                                 num_workers=0)

In [116]:
# Get the full train dataset ( for checking accuracy while training)
train_images = np.array(mnist_dataset.data[500:, :, :]).reshape(len(mnist_dataset.data[500:]), - 1) 
train_labels = one_hot(np.array(mnist_dataset.targets[500:]), n_targets)                                                                  

# Get test dataset
test_images = np.array(mnist_dataset.data[:500, :, :]).reshape(len(mnist_dataset.data[:500]), - 1) 
test_labels = one_hot(np.array(mnist_dataset.targets[:500]), n_targets) 

In [117]:
import time

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 5.89 sec
Training set accuracy 0.942874014377594
Test set accuracy 0.9500000476837158
Epoch 1 in 5.57 sec
Training set accuracy 0.9526386857032776
Test set accuracy 0.956000030040741
Epoch 2 in 5.16 sec
Training set accuracy 0.960084080696106
Test set accuracy 0.9600000381469727
Epoch 3 in 5.27 sec
Training set accuracy 0.9652605652809143
Test set accuracy 0.9620000720024109
Epoch 4 in 5.37 sec
Training set accuracy 0.9689244031906128
Test set accuracy 0.968000054359436
Epoch 5 in 5.19 sec
Training set accuracy 0.9720168709754944
Test set accuracy 0.9720000624656677
Epoch 6 in 5.29 sec
Training set accuracy 0.974907636642456
Test set accuracy 0.9720000624656677
Epoch 7 in 5.29 sec
Training set accuracy 0.9772941470146179
Test set accuracy 0.9740000367164612
