In [None]:
import jax
import jax.numpy as jnp
from jax import random

import numpy as np
import collections

In [None]:
LAYER_SIZES = [
    200 * 200 * 3,  # flattened input (presumably a 200x200 RGB image)
    2048,           # hidden layer 1
    1024,           # hidden layer 2
    2,              # output layer
]
# Multiplier applied to the N(0, 1) initial weights and biases.
PARAM_SCALE = 1e-2

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Sample one layer's parameters from a scaled standard normal.

    Args:
        m: fan-in (input width).
        n: fan-out (output width).
        key: JAX PRNG key; split into separate weight and bias keys.
        scale: multiplier applied to the N(0, 1) samples.

    Returns:
        ``(w, b)`` with ``w.shape == (n, m)`` and ``b.shape == (n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

def init_network_params(sizes, key=random.key(0), scale=0.01):
    """Build ``[(w, b), ...]`` for each consecutive pair of widths in ``sizes``.

    One sub-key is split off ``key`` per layer, so the result is reproducible
    for a given key.
    """
    layer_keys = random.split(key, len(sizes) - 1)
    layers = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(random_layer_params(fan_in, fan_out, layer_key, scale))
    return layers

In [None]:
# Fixed seed so the initialization below is reproducible.
key = random.key(42)
# NOTE(review): the first layer is 120000 x 2048, i.e. ~246M values (~1 GB
# assuming the default float32) — this cell is memory-heavy.
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)
params

In [None]:
# Replace every array leaf with its shape. The result keeps the params
# structure: one (w_shape, b_shape) pair per layer.
shapes = jax.tree.map(lambda p: p.shape, params)

for i, shape in enumerate(shapes):
    print(i, shape)

In [None]:
# A deliberately heterogeneous pytree: dicts, lists, tuples, a namedtuple, an
# OrderedDict, None, scalars, a string, and both NumPy and JAX arrays.
Point = collections.namedtuple('Point', ['x', 'y'])

example_pytree = [
    {
        'a': [1, 2, 3],
        'b': jnp.array([1, 2, 3]),   # arrays are single leaves
        'c': np.array([1, 2, 3]),
    },
    [42, [44, 46], None],            # None is an empty subtree: no leaf
    31337,
    (50, (60, 70)),
    Point(640, 480),                 # namedtuples are containers of their fields
    collections.OrderedDict([('a', 100), ('b', 200)]),
    'some string'                    # strings are not containers: one leaf
]

# Flatten to the list of leaves in traversal order.
jax.tree.leaves(example_pytree)

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

In [None]:
def _mnist_split(train):
    """Download (if needed) and load one MNIST split from ./data as tensors."""
    return datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )

training_data = _mnist_split(train=True)
test_data = _mnist_split(train=False)

In [None]:
from torch.utils.data import DataLoader

# Shuffle only the training split. Aggregate test metrics do not depend on
# example order, and a fixed test order keeps test-time outputs reproducible.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

In [None]:
# MNIST images: 28 x 28 pixels, one (grayscale) channel, 10 digit classes.
HEIGHT = 28
WIDTH = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = 10

In [None]:
# Derive the input and output widths from the dataset constants so the network
# shape stays in sync if HEIGHT/WIDTH/CHANNELS or NUM_LABELS change.
# Single hidden layer of 512 units.
LAYER_SIZES = [NUM_PIXELS, 512, NUM_LABELS]
PARAM_SCALE = 0.1

In [None]:
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [None]:
def init_network_params(sizes, key=random.key(0), scale=1e-2):
    """Initialize all layers of an MLP with scaled standard-normal draws.

    Args:
        sizes: layer widths ``[n_in, n_hidden..., n_out]``.
        key: JAX PRNG key; one sub-key is used per layer.
        scale: multiplier applied to the N(0, 1) samples for both weights and
            biases.

    Returns:
        A list with one ``(w, b)`` tuple per consecutive pair ``(m, n)`` in
        ``sizes``: ``w`` has shape ``(n, m)`` and ``b`` has shape ``(n,)``.
    """

    def random_layer_params(m, n, key, scale=1e-2):
        """Sample (w, b) for one layer: w is (n, m), b is (n,)."""
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

    # len(sizes) keys are split although only len(sizes) - 1 are consumed (zip
    # stops at the shortest input); kept so each layer's sub-key is unchanged.
    keys = random.split(key, len(sizes))
    # Forward `scale` explicitly: without it the helper's 1e-2 default silently
    # replaces the caller's value (e.g. PARAM_SCALE = 0.1 would be ignored).
    return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [None]:
# Initial parameters for the 784-512-10 MNIST MLP, seeded with key 0;
# PARAM_SCALE is passed as the init scale.
init_params = init_network_params(LAYER_SIZES, random.key(0), scale=PARAM_SCALE)

In [None]:
def predict(params, image):
    """Per-example forward pass of the MLP.

    Hidden layers compute ``swish(w @ a + b)``; the last layer is affine only,
    so the function returns unnormalized logits.

    Args:
        params: sequence of ``(w, b)`` pairs, first layer to last.
        image: one flattened input vector (batching is done via ``vmap``).
    """
    hidden_layers, (final_w, final_b) = params[:-1], params[-1]
    activations = image
    for w, b in hidden_layers:
        activations = swish(jnp.dot(w, activations) + b)
    return jnp.dot(final_w, activations) + final_b

In [None]:
# Map over the leading (batch) axis of the images only; params are shared
# across the batch (in_axes None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

In [None]:
# Learning-rate schedule consumed by `update`:
#   lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
# i.e. the rate shrinks by a factor of DECAY_RATE every DECAY_STEPS epochs.
INIT_LR, DECAY_RATE, DECAY_STEPS = 1.0, 0.95, 5

In [None]:
def loss(params, images, targets):
    """Categorical cross-entropy, averaged over all (example, class) entries.

    Args:
        params: list of ``(w, b)`` layer tuples.
        images: batch of flattened images, one row per example.
        targets: one-hot labels with the same shape as the logits.

    Returns:
        Scalar ``-mean(targets * log_softmax(logits))``. Because the mean runs
        over every (example, class) entry, this equals the usual per-example
        cross-entropy divided by the number of classes; that scaling is left
        unchanged so the loss magnitude seen by the learning rate is the same.
    """
    logits = batched_predict(params, images)
    # Normalize each example's logits separately (log-softmax per row).
    # Without axis/keepdims, logsumexp reduced over the whole batch and
    # subtracted one shared constant, giving wrong log-probabilities/gradients.
    log_preds = logits - logsumexp(logits, axis=-1, keepdims=True)
    return -jnp.mean(targets * log_preds)

@jax.jit
def update(params, x, y, epoch_number):
    """One jitted SGD step with an exponentially decaying learning rate.

    The two prints execute while JAX traces the function (not on cached
    calls); they show that the gradient pytree mirrors the params pytree.

    Returns:
        ``(new_params, loss_value)``; new_params has the same list-of-(w, b)
        structure as ``params``.
    """
    print(f"Params shapes: {jax.tree.map(lambda p: p.shape, params)}")
    loss_value, grads = value_and_grad(loss)(params, x, y)
    print(f"Grads shapes: {jax.tree.map(lambda g: g.shape, grads)}")
    decay_factor = DECAY_RATE ** (epoch_number / DECAY_STEPS)
    lr = INIT_LR * decay_factor
    # Step every leaf against its matching gradient leaf in one tree.map.
    new_params = jax.tree.map(lambda param, grad: param - lr * grad, params, grads)
    return new_params, loss_value

In [None]:
x, y = next(iter(train_dataloader))
# Flatten each image in the batch into a NUM_PIXELS-long row. Use the actual
# batch length rather than a hard-coded size so a different DataLoader
# batch_size (or a short final batch) still works.
x = jnp.reshape(x.numpy(), (len(x), NUM_PIXELS))
y = one_hot(y.numpy(), NUM_LABELS)

In [None]:
# One jitted training step on the single batch prepared above (epoch 0). This
# first call traces `update`, so its shape prints appear here.
params, loss_value = update(init_params, x, y, 0)

### Functions for working with pytrees

In [None]:
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)

In [None]:
scaled_params = jax.tree.map(lambda p: 10 * p, params)

In [None]:
# A small nested pytree of plain Python ints (lists inside lists).
some_pytree = [
    [1, 1, 1],
    [
        [10, 10, 10], [20, 20]
    ]
]

In [None]:
# Add 1 to every leaf; the nesting is preserved in the result.
jax.tree.map(lambda p: p+1, some_pytree)

In [None]:
# Split the tree into a flat list of leaves and a PyTreeDef describing the
# structure, so the two can be recombined later.
leaves, struct = jax.tree.flatten(some_pytree)

In [None]:
print(leaves)
print(struct)

In [None]:
updated_leaves = map(lambda x: x+1, leaves)

In [None]:
jax.tree.unflatten(struct, updated_leaves)

In [None]:
from jax.flatten_util import ravel_pytree

In [None]:
leaves, unflatten_func = ravel_pytree(some_pytree)

In [None]:
print(leaves)
print(unflatten_func)

In [None]:
unflatten_func(leaves)

### Reducing a tree

In [None]:
# Fold every leaf into a single value starting from 0 — here, the sum of all
# numbers in some_pytree.
jax.tree.reduce(lambda total, leaf: total + leaf, some_pytree, initializer=0)

In [None]:
import math
from collections import namedtuple

### Transposing a pytree

In [None]:
# NOTE(review): this redefines the `Point` namedtuple created in an earlier
# cell with the same fields; instances built before and after this cell belong
# to different classes.
Point = namedtuple('Point', ['x', 'y'])

In [None]:
# Three points as a Python list of Point(x, y) with float coordinates.
points = [
    Point(0.0, 0.0),
    Point(3.0, 0.0),
    Point(0.0, 4.0)
]

In [None]:
def rotate_point(p, theta):
    """Rotate point ``p`` counter-clockwise by ``theta`` radians about the origin.

    Only arithmetic is applied to ``p.x`` / ``p.y``, so array coordinates (e.g.
    under ``jax.vmap``) work as well; ``theta`` must be a Python number since it
    goes through ``math.cos`` / ``math.sin``.
    """
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    return Point(p.x * cos_t - p.y * sin_t,
                 p.x * sin_t + p.y * cos_t)

In [None]:
# Rotating (3, 0) by pi gives approximately (-3, 0); the tiny y value comes
# from floating-point error in sin(pi).
rotate_point(points[1], math.pi)

In [None]:
# Expected failure: with in_axes=0, vmap maps every leaf of `points` along axis
# 0, but those leaves are Python floats (rank 0), so vmap raises. The error is
# caught and shown so "Run All" continues; transposing the pytree fixes this.
try:
    jax.vmap(rotate_point, in_axes=(0, None))(points, math.pi)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
# Structure of the whole list: a list node holding three Point nodes, each
# with two leaves.
jax.tree.structure(points)

In [None]:
# Structure of a single point: one Point node with two leaves (x, y).
jax.tree.structure(points[0])

In [None]:
# Turn a list of Points into a Point of lists ("array of structs" -> "struct
# of arrays"). The outer treedef only needs the right shape — a flat list with
# one slot per point — so its leaf values are irrelevant.
points_t = jax.tree.transpose(
    pytree_to_transpose=points,
    inner_treedef=jax.tree.structure(points[0]),
    outer_treedef=jax.tree.structure([0] * len(points)),
)
points_t

In [None]:
# Convert each coordinate list into a JAX array so every leaf has a leading
# axis of length 3 for vmap to map over.
points_t_array = Point(jnp.array(points_t.x), jnp.array(points_t.y))
points_t_array

In [None]:
# Now vmap maps rotate_point over the three points while theta stays
# unbatched (in_axes None).
jax.vmap(rotate_point, in_axes=(0, None))(points_t_array, math.pi)

### Creating custom pytree nodes

In [None]:
class Layer:
    """A named layer container holding weights ``w`` and bias ``b``."""

    def __init__(self, name, w, b):
        self.name = name
        self.w = w
        self.b = b

    def __repr__(self):
        # Without this, cells that display pytrees containing a Layer show only
        # `<__main__.Layer object at 0x...>`, hiding whether w and b changed.
        return f"Layer(name={self.name!r}, w={self.w!r}, b={self.b!r})"

In [None]:
# Layer is not registered with JAX yet, so it is treated as an opaque leaf.
h1 = Layer('hidden1', jnp.zeros((100,20)), jnp.zeros((20,)))

In [None]:
pt = [
    jnp.ones(50),
    h1
]

In [None]:
# Two leaves: the ones array and the whole Layer object (w and b not exposed).
jax.tree.leaves(pt)

In [None]:
# Expected failure: Layer is not registered yet, so JAX treats the whole object
# as one leaf and calls the lambda on it; `Layer * 10` is unsupported. The
# error is caught and shown so "Run All" continues.
try:
    jax.tree.map(lambda x: x*10, pt)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")


In [None]:
def flatten_layer(container):
    """Split a Layer for JAX: children are ``[w, b]``; ``name`` is aux data.

    Aux data is stored in the treedef rather than treated as a leaf, so the
    name passes through tree operations unchanged.
    """
    return [container.w, container.b], container.name

def unflatten_layer(aux_data, flat_contents):
    """Rebuild a Layer from its aux data (the name) and children (w, b)."""
    return Layer(aux_data, *flat_contents)

In [None]:
# Tell JAX how to take a Layer apart (flatten_layer) and rebuild it
# (unflatten_layer), making Layer a pytree node instead of a leaf.
# NOTE(review): JAX rejects registering the same type twice, so re-running this
# cell on its own presumably raises — re-run the `class Layer` cell first.
jax.tree_util.register_pytree_node(
    Layer,
    flatten_layer,
    unflatten_layer
)

In [None]:
# Recreate the example now that Layer is a registered pytree node.
h1 = Layer('hidden1', jnp.zeros((100, 20)), jnp.zeros((20,)))

In [None]:
pt = [
    jnp.ones(50), 
    h1
]

In [None]:
# Three leaves now: the ones array plus the Layer's w and b (name is aux data).
jax.tree.leaves(pt)

In [None]:
# Succeeds now: the lambda is applied to w and b and a new Layer is rebuilt.
# The result is only displayed, not stored.
jax.tree.map(lambda x: x*10, pt)

In [None]:
# Same leaves as before: tree.map returned a new tree instead of mutating `pt`.
jax.tree.leaves(pt)

In [None]:
pt2 = jax.tree.map(lambda x: x+1, pt)

In [None]:
pt2

In [None]:
# pt2's leaves are pt's leaves plus 1.
jax.tree.leaves(pt2)