In [1]:
# https://roberttlange.com/posts/2020/03/blog-post-10/ 
import jax
from typing import Any, Callable, Sequence
import flax
from flax import linen as nn

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random

# Root PRNG key (seed 1). The random.* calls in this notebook take a key
# explicitly rather than relying on global random state.
key = random.PRNGKey(1)

In [2]:
# 1000 x 1000 matrix of uniform random numbers used by the timing comparisons
x = random.uniform(key, (1000, 1000))

# NumPy matrix product applied to the JAX array
%time y = np.dot(x,x)

# JAX dispatches work asynchronously: without block_until_ready() the timing
# stops as soon as the work is dispatched, not when the result is available.
%time y = jnp.dot(x,x) # dispatch only; the result is not waited for
%time y = jnp.dot(x,x).block_until_ready()  

CPU times: user 54.1 ms, sys: 39.3 ms, total: 93.4 ms
Wall time: 16.6 ms
CPU times: user 20.5 ms, sys: 202 ms, total: 223 ms
Wall time: 10.4 ms
CPU times: user 151 ms, sys: 1.94 s, total: 2.09 s
Wall time: 130 ms


In [3]:
# Left undecorated (no @jax.jit) so that the plain function and its jitted
# wrapper below both stay available.
def ReLU(x):
    """Rectified linear unit: element-wise maximum of zero and ``x``."""
    rectified = jnp.maximum(0, x)
    return rectified


# Compiled wrapper around ReLU
jit_relu = jit(ReLU)
# jit traces functions with abstract values: Python control flow may branch on static properties (shape, dtype), but not on the concrete values of traced arrays

In [4]:
# Plain (un-jitted) ReLU
%time out = ReLU(x)

# Jitted ReLU, result not waited for.
# NOTE(review): as the first call to jit_relu this presumably also includes
# tracing/compilation -- confirm before reading it as pure dispatch time.
%time out = jit_relu(x)

# Jitted ReLU, waiting for the result so the computation itself is timed
%time out = jit_relu(x).block_until_ready() 

CPU times: user 18 ms, sys: 48 μs, total: 18.1 ms
Wall time: 17.9 ms
CPU times: user 16.8 ms, sys: 416 μs, total: 17.2 ms
Wall time: 16.9 ms
CPU times: user 638 μs, sys: 698 μs, total: 1.34 ms
Wall time: 390 μs


In [5]:
def FiniteDiffGrad(x, eps=1e-3):
    """Central finite-difference approximation of the ReLU derivative at ``x``.

    Args:
        x: point (or array of points) at which to approximate the derivative.
        eps: half-width of the difference stencil. The default of 1e-3 keeps
            the step that used to be hard-coded.

    Returns:
        A jnp array holding ``(ReLU(x + eps) - ReLU(x - eps)) / (2 * eps)``.
    """
    return jnp.array((ReLU(x + eps) - ReLU(x - eps)) / (2 * eps))

# grad performs automatic differentiation (autodiff):
# differentiate the jitted ReLU activation, then jit the resulting gradient function
print("Jax Grad:", jit(grad(jit(ReLU)))(2.))
# central finite-difference approximation of the same derivative, for comparison
print("FD Gradient:", FiniteDiffGrad(2.) )

Jax Grad: 1.0
FD Gradient: 0.99998707


In [6]:
batch_dim = 64
feature_dim = 50
hidden_dim = 256

# Use a separate subkey for every draw: sampling several arrays from one and
# the same PRNG key produces correlated values. `key` itself is left untouched
# so later cells that use it are unaffected.
x_key, w_key, b_key = random.split(key, 3)

# Batch of inputs, one example per row
X = random.normal(x_key, (batch_dim, feature_dim))

# [weights, bias] of a single dense layer
params = [
    random.normal(w_key, (hidden_dim, feature_dim)),
    random.normal(b_key, (hidden_dim,)),
]

def relu_layer(params, x):
    """Dense layer followed by ReLU.

    ``params[0]`` is used as the weight matrix and ``params[1]`` as the bias;
    the result is ``ReLU(dot(params[0], x) + params[1])``.
    """
    weights = params[0]
    bias = params[1]
    pre_activation = jnp.dot(weights, x) + bias
    return ReLU(pre_activation)
# for a single row of X, relu_layer returns an array of shape (256,) (= hidden_dim); stacking the results for all 64 rows (X.shape[0]) gives (64, 256)

def batch_relu_layer(params, X):
    """Dense layer + ReLU applied to a whole batch at once.

    Multiplies ``X`` by the transposed weight matrix ``params[0].T``, adds the
    bias ``params[1]`` and applies ReLU.
    """
    weights_t = params[0].T
    pre_activation = jnp.dot(X, weights_t) + params[1]
    return ReLU(pre_activation)

def vmap_relu(params, x):
    """Apply ``relu_layer`` to every row of ``x`` via ``vmap`` and ``jit``.

    Args:
        params: layer parameters, shared by all rows (``in_axes`` entry ``None``).
        x: batch of inputs; axis 0 is the axis that is mapped over.

    Returns:
        The per-row outputs of ``relu_layer`` stacked along axis 0.
    """
    batched_layer = jit(
        vmap(
            # in_axes: params are passed whole (None), x is split along axis 0;
            # out_axes=0 stacks the per-row results along the first dimension
            # (like jnp.stack over X.shape[0] results).
            relu_layer, in_axes=(None, 0), out_axes=0
        )
    )
    # Invoke the transformed function so that an array, not a callable, is returned.
    return batched_layer(params, x)
    
# Reference path: apply the single-example layer row by row and stack the results
out = jnp.stack([
    relu_layer(params, X[i, :]) for i in range(X.shape[0])
    ])

# Not the last expression of the cell, so the shape is not displayed
out.shape # (64, 256)

# Alternative paths: the hand-batched layer, then vmap_relu
out = batch_relu_layer(params, X)
out = vmap_relu(params, X)

In [7]:
#MLP for MNIST
from jax.scipy.special import logsumexp
import torch
from torchvision import datasets, transforms
from jax.example_libraries import optimizers

import time

In [8]:
batch_size = 100

# Training split of MNIST, stored under ./data (downloaded if missing).
# Images are converted to tensors and normalised with the given mean / std.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=0.1307, std=0.3081),
        ]),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [9]:
# Test split of MNIST read from ./data, with the same tensor conversion and
# normalisation constants as the training loader.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=0.1307, std=0.3081),
        ]),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [10]:
def initialize_mlp(sizes, key):
    """Create Gaussian-initialised parameters for a fully connected network.

    Args:
        sizes: layer widths, input first; consecutive pairs define one layer.
        key: PRNG key; it is split into one key per entry of ``sizes`` and the
            layers consume them in order.

    Returns:
        A list with one ``(weights, bias)`` tuple per layer, where ``weights``
        has shape ``(n_out, n_in)`` and ``bias`` has shape ``(n_out,)``.
    """
    def initialize_layer(m, n, key, scale=1e-2):
        # Separate keys for weights and bias; both are scaled normal draws.
        weight_key, bias_key = random.split(key)
        weights = scale * random.normal(weight_key, (n, m))
        bias = scale * random.normal(bias_key, (n,))
        return weights, bias

    layer_keys = random.split(key, len(sizes))

    layers = []
    # zip stops at the shortest input, so the surplus key is simply unused
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(initialize_layer(fan_in, fan_out, layer_key))
    return layers
    
layer_sizes = [784, 512, 512, 10]

# Rebinds `params` to the freshly initialised MLP parameters
params = initialize_mlp(layer_sizes, key)

# Print the weight and bias shape of every layer
for p in params:
    layer_weights, layer_bias = p[0], p[1]
    print(layer_weights.shape, layer_bias.shape)

(512, 784) (512,)
(512, 512) (512,)
(10, 512) (10,)


In [11]:
def forward_pass(params, in_array):
    """Run one input through the MLP and return normalised log-scores.

    All layers except the last go through ``relu_layer``; the last layer is
    affine only. The result is ``logits - logsumexp(logits)``, i.e. the
    log-softmax of the logits.
    """
    hidden = in_array
    for layer in params[:-1]:
        weights, bias = layer
        hidden = relu_layer([weights, bias], hidden)

    # Output layer: affine map without activation
    out_weights, out_bias = params[-1]
    logits = jnp.dot(out_weights, hidden) + out_bias
    log_normaliser = logsumexp(logits)
    return logits - log_normaliser


# Batched forward pass: parameters are shared (None), inputs are mapped over axis 0
batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)

In [12]:
def one_hot(x, key, dtype=jnp.float32):
    """One-hot encode ``x`` against the classes ``0 .. key - 1``.

    Despite its name, ``key`` is the number of classes (it is passed to
    ``jnp.arange``), not a PRNG key. ``x[:, None]`` inserts a new axis at
    position 1 before the comparison, so a 1-D ``x`` of length n yields an
    ``(n, key)`` array; the boolean matches are returned cast to ``dtype``.
    """
    class_ids = jnp.arange(key)
    matches = x[:, None] == class_ids
    return jnp.array(matches, dtype)


# Try one_hot on a 3-D tensor: x[:, None] inserts a new axis at position 1,
# so the (30, 24, 8) input gives a (30, 1, 24, 8) result.
# Note that this rebinds `x`.
x = random.uniform(key, (30, 24,8))

x_ = one_hot(x, 8)
print(x_.shape)

(30, 1, 24, 8)


In [13]:
def loss(params, in_arrays, targets):
    """Summed cross-entropy between the batched predictions and ``targets``.

    ``batch_forward`` returns log-scores, which are multiplied element-wise
    with ``targets`` and summed over all axes; the negated sum is returned.
    """
    log_probs = batch_forward(params, in_arrays)
    # No axis argument: sum over every example and class at once
    total = jnp.sum(log_probs * targets)
    return -total

def accuracy(params, data_loader):
    """Fraction of the examples served by ``data_loader`` that are classified correctly.

    Correct predictions are counted batch by batch and divided by
    ``len(data_loader.dataset)``.
    """
    num_correct = 0
    for data, target in data_loader:
        # Flatten every image in the batch to a 784-vector
        flat_images = jnp.array(data).reshape(data.size(0), 784)
        # One-hot encode the labels over 10 classes, then take the argmax
        # to get one class index per example
        encoded_targets = one_hot(jnp.array(target), 10)
        true_class = jnp.argmax(encoded_targets, axis=1)
        # The highest-scoring output determines the predicted class
        scores = batch_forward(params, flat_images)
        predicted = jnp.argmax(scores, axis=1)
        num_correct += jnp.sum(predicted == true_class)
    return num_correct / len(data_loader.dataset)

In [14]:
@jit
def update(params, x, y, opt_state, step=0):
    """Perform one optimisation step and return the updated state.

    Args:
        params: current parameters at which loss and gradient are evaluated.
        x: batch of inputs passed to ``loss``.
        y: batch of targets passed to ``loss``.
        opt_state: current optimizer state.
        step: iteration index forwarded to ``opt_update``. Defaults to 0, the
            value that used to be hard-coded. Optimizers that use the step
            index (for example for bias correction or step-size schedules)
            need the real iteration count, so callers should pass it.

    Returns:
        A tuple ``(new_params, new_opt_state, loss_value)``.
    """
    # Evaluate the loss and its gradient w.r.t. params in a single pass
    value, grads = value_and_grad(loss)(params, x, y)
    # Advance the optimizer state using the supplied step index
    opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state), opt_state, value
    
# Adam from jax.example_libraries.optimizers returns three functions:
# opt_init builds the optimizer state from the parameters, opt_update applies
# one step, and get_params reads the parameters back out of the state.
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

# Training configuration
num_epochs = 1
num_classes = 10

In [15]:
def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
    """Train on ``train_loader`` for ``num_epochs`` epochs and log the progress.

    Args:
        num_epochs: number of passes over ``train_loader``.
        opt_state: initial optimizer state; the starting parameters are read
            from it with ``get_params``.
        net_type: ``"MLP"`` flattens every image to a 784-vector before the
            update, ``"CNN"`` passes the image batch through unchanged.

    Returns:
        A tuple ``(train_loss, log_acc_train, log_acc_test)``: the per-batch
        loss values, and the train / test accuracies measured before training
        and after every epoch.

    Raises:
        ValueError: if ``net_type`` is neither ``"MLP"`` nor ``"CNN"``.
    """
    # Fail early on an unsupported network type; otherwise the input batch
    # would never be assigned and the loop would break on the first batch.
    if net_type not in ("MLP", "CNN"):
        raise ValueError(
            "net_type must be 'MLP' or 'CNN', got {!r}".format(net_type))

    # Initialize placeholders for logging
    log_acc_train, log_acc_test, train_loss = [], [], []

    # Get the initial set of parameters
    params = get_params(opt_state)

    # Get initial accuracy after random init
    train_acc = accuracy(params, train_loader)
    test_acc = accuracy(params, test_loader)
    log_acc_train.append(train_acc)
    log_acc_test.append(test_acc)

    # Loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for data, target in train_loader:
            if net_type == "MLP":
                # Flatten the image into 784 vectors for the MLP
                x = jnp.array(data).reshape(data.size(0), 28*28)
            else:
                # No flattening of the input required for the CNN
                x = jnp.array(data)
            y = one_hot(jnp.array(target), num_classes)
            # Named batch_loss so the module-level `loss` function is not shadowed
            params, opt_state, batch_loss = update(params, x, y, opt_state)
            train_loss.append(batch_loss)

        epoch_time = time.time() - start_time
        train_acc = accuracy(params, train_loader)
        test_acc = accuracy(params, test_loader)
        log_acc_train.append(train_acc)
        log_acc_test.append(test_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}".format(epoch+1, epoch_time,
                                                                    train_acc, test_acc))

    return train_loss, log_acc_train, log_acc_test


# Train the MLP for num_epochs epochs, keeping the per-batch losses and the
# train / test accuracy logs
train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                        opt_state,
                                                        net_type="MLP")

# # Plot the loss curve over time
# from helpers import plot_mnist_performance
# plot_mnist_performance(train_loss, train_log, test_log,
#                        "MNIST MLP Performance")

Epoch 1 | T: 5.13 | Train A: 0.971 | Test A: 0.968


MLA