<a href="https://colab.research.google.com/github/alexandrumeterez/ai_notebooks/blob/master/JAX_stuffs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [0]:
def random_params(in_size, out_size, key, scale=1e-2):
    """Draw scaled-Gaussian weights and biases for one dense layer.

    Args:
        in_size: number of input features (columns of the weight matrix).
        out_size: number of output units (rows of the weight matrix and
            length of the bias vector).
        key: JAX PRNG key; it is split once so that weights and biases are
            sampled from different subkeys.
        scale: multiplier applied to the standard-normal samples.

    Returns:
        Tuple ``(weights, biases)`` with shapes ``(out_size, in_size)`` and
        ``(out_size,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (out_size, in_size))
    biases = scale * random.normal(bias_key, (out_size,))
    return weights, biases

In [0]:
def init_network(sizes, key):
    """Create the parameters of every dense layer of an MLP.

    Args:
        sizes: layer widths, input width first; each consecutive pair
            ``(sizes[i], sizes[i + 1])`` defines one layer.
        key: JAX PRNG key. It is split into ``len(sizes)`` subkeys and layer
            ``i`` (0-based) is initialised from subkey ``i``, so the last
            subkey is left unused.

    Returns:
        List with one ``(weights, biases)`` tuple per layer, as produced by
        ``random_params``.
    """
    subkeys = random.split(key, len(sizes))
    return [
        random_params(fan_in, fan_out, subkeys[idx])
        for idx, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:]))
    ]

In [0]:
# Layer widths, input first: 784 inputs, two hidden layers of 512 units, 10 outputs.
layers = [784, 512, 512, 10]
# Fixed PRNG seed (0) so the initial parameters are the same on every run.
params = init_network(layers, random.PRNGKey(0))

In [5]:
# Sanity check: every entry of params is a (weights, biases) pair.
print([p[0].shape for p in params]) # weights: (out_size, in_size) per layer
print([p[1].shape for p in params]) # biases: (out_size,) per layer

[(512, 784), (512, 512), (10, 512)]
[(512,), (512,), (10,)]


In [0]:
def relu(x):
    """Rectified linear unit: element-wise maximum of 0 and ``x``."""
    rectified = np.maximum(0, x)
    return rectified

def softmax(x):
    """Numerically stable softmax along axis 0.

    Args:
        x: array of scores; normalisation is performed over axis 0.

    Returns:
        Array of the same shape as ``x`` whose entries along axis 0 are
        non-negative and sum to 1.
    """
    # Softmax is unchanged by adding a constant to every score, so shift by
    # the maximum first: the largest exponent becomes exp(0) = 1 and exp()
    # can no longer overflow to inf (which previously produced inf/inf = nan).
    shifted = x - np.max(x, axis=0, keepdims=True)
    # Compute the exponentials once and reuse them for the normaliser.
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)

def cross_entropy(x, y):
    """Mean of ``-x * log(y)`` taken over every element.

    Args:
        x: target weights (called with the targets as first argument).
        y: predicted probabilities, broadcast-compatible with ``x``.

    Returns:
        Scalar loss. The mean runs over all elements of the broadcast
        product, not only over the leading axis.
    """
    # Entries whose target weight is 0 contribute nothing (0 * log y := 0).
    # Substitute 1 for y at those positions *before* taking the log so that a
    # prediction that underflowed to 0 cannot create 0 * -inf = nan, neither
    # in the value nor in the gradient.
    contributes = x != 0
    safe_y = np.where(contributes, y, 1.0)
    return -np.mean(x * np.log(safe_y))

In [0]:
def predict(params, image):
    """Forward pass of the MLP for a single input vector.

    Args:
        params: list of ``(weights, biases)`` tuples, one per layer.
        image: input vector fed to the first layer.

    Returns:
        ``softmax`` of the last layer's affine output. Every layer before
        the last is followed by ``relu``.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    signal = image
    for weight, bias in hidden_layers:
        signal = relu(np.dot(weight, signal) + bias)

    out_weight, out_bias = output_layer
    return softmax(np.dot(out_weight, signal) + out_bias)

In [8]:
# Smoke test: run the network on one random vector of length 28*28.
random_image = random.normal(random.PRNGKey(1), (28*28, ))
# NOTE: predict returns the softmax output, so despite its name `logits`
# holds normalised scores rather than raw logits.
logits = predict(params, random_image)
print(logits.shape)

(10,)


In [0]:
# predict is written for a single input vector, so it cannot be applied to a
# batch directly. vmap maps it over axis 0 of the second argument (the images);
# in_axes=None for the first argument means params are shared by every example.
batched_predict = vmap(predict, in_axes=(None, 0))

In [10]:
# Batch holding a single random input, to check that the vmapped function
# accepts and returns a leading batch axis.
random_image_batch = random.normal(random.PRNGKey(1), (1, 28*28))
print(random_image_batch.shape)
# As above, `logits` holds the softmax outputs returned by predict.
logits = batched_predict(params, random_image_batch)
print(logits.shape)

(1, 784)
(1, 10)


In [11]:
# Show the network outputs for the first (and only) entry of the batch.
print(logits[0])

[0.10200009 0.10042951 0.1009826  0.09725185 0.09946808 0.09856772
 0.09960646 0.10060352 0.10020161 0.10088862]


In [0]:
def loss(params, images, targets):
    """Cross-entropy between ``targets`` and the batched predictions.

    Args:
        params: network parameters passed through to ``batched_predict``.
        images: batch of inputs, one example per row.
        targets: target array passed as the first argument of
            ``cross_entropy``.

    Returns:
        The value returned by ``cross_entropy``.
    """
    return cross_entropy(targets, batched_predict(params, images))

In [13]:
# NOTE(review): the targets here are the raw values 0..9 rather than a one-hot
# row; they broadcast against the (1, 10) predictions, so this call only shows
# that loss runs and returns a scalar - the number itself is not a meaningful
# cross-entropy.
loss(params, random_image_batch, np.array([0,1,2,3,4,5,6,7,8,9]))

DeviceArray(10.365537, dtype=float32)

In [0]:
@jit
def update(params, x, y, step_size=0.01):
    """Perform one step of plain gradient descent on ``loss``.

    Args:
        params: list of ``(weights, biases)`` tuples, one per layer.
        x: batch of inputs forwarded to ``loss``.
        y: targets forwarded to ``loss``.
        step_size: learning rate. Defaults to 0.01, the value that used to be
            hard-coded, so existing three-argument calls behave as before.

    Returns:
        A new list of ``(weights, biases)`` tuples; ``params`` itself is not
        modified.
    """
    # grad differentiates with respect to the first argument (params), so
    # grads has the same list-of-tuples structure as params.
    grads = grad(loss)(params, x, y)
    return [
        (w - step_size * dw, b - step_size * db)
        for (w, b), (dw, db) in zip(params, grads)
    ]

In [0]:
def one_hot(x, k, dtype=np.float32):
  """Return the one-hot encoding of label array ``x`` using ``k`` columns.

  Row ``i`` has a 1 in column ``x[i]`` and 0 elsewhere; labels outside
  ``0..k-1`` give an all-zero row. The result is cast to ``dtype``.
  """
  class_ids = np.arange(k)
  matches = x[:, None] == class_ids
  return np.array(matches, dtype)

In [0]:
# Number of samples per mini-batch; passed to the training data loader.
batch_size = 256
import numpy as onp
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  """Collate a sequence of samples into NumPy arrays.

  * samples that are ndarrays are stacked along a new leading axis;
  * samples that are tuples/lists are collated field by field (recursively),
    giving a list with one collated entry per field;
  * anything else is converted with ``onp.array``.
  """
  first = batch[0]
  if isinstance(first, onp.ndarray):
    return onp.stack(batch)
  if isinstance(first, (tuple, list)):
    # zip(*batch) turns [(a0, b0), (a1, b1), ...] into (a0, a1, ...), (b0, b1, ...).
    return [numpy_collate(field) for field in zip(*batch)]
  return onp.array(batch)

class NumpyLoader(data.DataLoader):
  """DataLoader whose batches are always collated with ``numpy_collate``.

  The constructor accepts the listed DataLoader options and forwards them
  unchanged; only ``collate_fn`` is fixed to ``numpy_collate``.
  """
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    # Zero-argument super() is bound to NumpyLoader itself. The previous
    # super(self.__class__, self) looked up type(self) instead, which makes
    # __init__ call itself forever as soon as NumpyLoader is subclassed.
    super().__init__(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast:
  """Callable transform: convert the input to a float32 NumPy array and flatten it to 1-D."""
  def __call__(self, pic):
    as_float = onp.array(pic, dtype=np.float32)
    return onp.ravel(as_float)

In [0]:
# Define our dataset, using torch datasets. Files are stored under /tmp/mnist/
# (download=True); FlattenAndCast is installed as the per-image transform.
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
# Batches of `batch_size` samples collated to NumPy arrays. shuffle is left at
# NumpyLoader's default of False.
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

In [18]:
n_targets = 10  # width of the one-hot label encoding
# Get the full train dataset (for checking accuracy while training).
# torchvision renamed train_data/train_labels/test_data/test_labels to
# data/targets; the old names are deprecated aliases that emit warnings, so
# the current attribute names are used here.
train_images = onp.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)
train_labels = one_hot(onp.array(mnist_dataset.targets), n_targets)

# Get full test dataset (train=False selects the test split); images are
# flattened to one row per example and cast to float32.
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = np.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=np.float32)
test_labels = one_hot(onp.array(mnist_dataset_test.targets), n_targets)



In [19]:
# Train for 20 passes over the training loader, printing the mean batch loss
# of each pass.
for epoch in range(20):
    epoch_loss = 0.0
    n_batches = len(training_generator)
    for i, (x, y) in enumerate(training_generator):
        # y presumably arrives as a vector of integer class labels (TODO
        # confirm); the loss expects one-hot rows of width n_targets.
        y = one_hot(y, n_targets)
        # update returns a new parameter list; rebinding the global `params`
        # carries the trained weights into the next step.
        params = update(params, x, y)
        # The loss is evaluated after the step, on the batch just trained on,
        # which costs one extra forward pass per batch.
        epoch_loss += loss(params, x, y)
    print(f"Epoch: {epoch+1}\n\tLoss: {epoch_loss / n_batches}")

Epoch: 1
	Loss: 0.05866694822907448
Epoch: 2
	Loss: 0.027958426624536514
Epoch: 3
	Loss: 0.02301833964884281
Epoch: 4
	Loss: 0.020027145743370056
Epoch: 5
	Loss: 0.017832688987255096
Epoch: 6
	Loss: 0.016106652095913887
Epoch: 7
	Loss: 0.014697279781103134
Epoch: 8
	Loss: 0.01351588498800993
Epoch: 9
	Loss: 0.012503830716013908
Epoch: 10
	Loss: 0.011626585386693478
Epoch: 11
	Loss: 0.010854005813598633
Epoch: 12
	Loss: 0.010166967287659645
Epoch: 13
	Loss: 0.009551472961902618
Epoch: 14
	Loss: 0.008995779789984226
Epoch: 15
	Loss: 0.008492236956954002
Epoch: 16
	Loss: 0.008033220656216145
Epoch: 17
	Loss: 0.007611940614879131
Epoch: 18
	Loss: 0.0072233108803629875
Epoch: 19
	Loss: 0.006865765433758497
Epoch: 20
	Loss: 0.006535021588206291
