In [6]:
from jax.scipy.special import logsumexp
from jax.experimental import optimizers
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random

import torch
from torchvision import datasets, transforms

import time

In [2]:
batch_size = 100

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data',train=True, download=True,transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,),(0.3081,))
    ])),batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data',train=False, download=True,transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,),(0.3081,))
    ])),batch_size=batch_size,shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [7]:
key = random.PRNGKey(1)

In [20]:
def ReLU(x):
    return np.maximum(0,x)

In [21]:
def relu_layer(params, x):
    return ReLU(np.dot(params[0],x)+params[1])

In [50]:
def initialize_mlp(sizes, key):
    keys = random.split(key, len(sizes))
    def initialize_layer(m,n,key,scale=1e-2):
        w_key, b_key = random.split(key)
        return scale*random.normal(w_key,(n,m)), scale*random.normal(b_key,(n,))
    return [initialize_layer(m,n,k) for m,n,k in zip(sizes[:-1],sizes[1:],keys)]

layer_sizes = [784,512,512,10]
params = initialize_mlp(layer_sizes,key)

In [53]:
print(params[0][0].shape)
print(params[0][1].shape)
print(params[1][0].shape)
print(params[1][1].shape)
print(params[2][0].shape)
print(params[2][1].shape)

(512, 784)
(512,)
(512, 512)
(512,)
(10, 512)
(10,)


In [28]:
def forward_pass(params, in_array):
    ''' compute the forward pass for each example individually'''
    activations = in_array
    # loop over the ReLU hidden layers
    for w,b in params[:-1]:
        activations = relu_layer([w,b],activations)
    # performs final trafo to logits
    final_w, final_b = params[-1]
    logits = np.dot(final_w,activations)+final_b
    return logits-logsumexp(logits)

batch_forward = vmap(forward_pass, in_axes=(None,0),out_axes=0)

In [32]:
def one_hot(x,k,dtype=np.float32):
    ''' create a one=hot encoding of x of size k'''
    return np.array(x[:,None]==np.arange(k),dtype)
def loss(params, in_arrays, targets):
    '''compute the multi-class cross-entorpy loss'''
    preds = batch_forward(params, in_arrays)
    return -np.sum(preds*targets)
def accuracy(params, data_loader):
    acc_total=0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data).reshape(data.size(0),28*28)
        targets = one_hot(np.array(target),num_classes)
        
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(batch_forward(params,images),axis=1)
        acc_total += np.sum(predicted_class==target_class)
    return acc_total/len(data_loader.dataset)

In [35]:
'''
jax.value_and_grad(fun) create a function that evaluates both fun and the gradient of fun
opt_init, opt_update, get_params = optimizers.adam(step_size)
    init_func(params): 
        args: 
            params: pytree representing the initial parameters
        returns:
            a pytree representing the initial optimizer state, which includes the initial parameters and may also include auxiliary values like initial momentum
    update_fun(step, grads, opt_state)
        args:
            step: integer representing the step index
            grads: a pytree representing the gradients to be used in updating the optimizer state
            opt_state: a pytree representing the optimizer state to be updated
        returns:
            a pytree with the same structure as the 'opt_state' argument representing the updated optimizer state
    get_params(opt_state):
        args:
            opt_state: pytree representing an optimizer state
        returns:
            a pytree representing the parameters extracted from 'opt_state', such that the invariant 'params==get_params(init_fun(params))' holds true
        
'''
@jit
def update(params, x, y, opt_state):
    ''' compute the gradient for a batch and update the parameter'''
    value, grads = value_and_grad(loss)(params, x, y)
    opt_state = opt_update(0,grads,opt_state)
    return get_params(opt_state),opt_state, value
step_size=1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

num_epochs = 10
num_classes = 10

In [37]:
def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
    '''implements a learning loop over epochs'''
    # initialize placeholder for loggin
    log_acc_train, log_acc_test, train_loss = [], [], []
    # get the initial set of parameters
    params = get_params(opt_state)
    
    # get initial accuracy after random init
    train_acc = accuracy(params, train_loader)
    test_acc = accuracy(params, test_loader)
    log_acc_train.append(train_acc)
    log_acc_test.append(test_acc)
    
    # loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            if net_type=="MLP":
                x = np.array(data).reshape(data.size(0),28*28)
            elif net_type == "CNN":
                x=np.array(data)
            y = one_hot(np.array(target),num_classes)
            params, opt_state, loss = update(params, x, y, opt_state)
            train_loss.append(loss)
            
        epoch_time = time.time() - start_time
        train_acc = accuracy(params, train_loader)
        test_acc = accuracy(params, test_loader)
        log_acc_train.append(train_acc)
        log_acc_test.append(test_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}".format(epoch+1, epoch_time,train_acc, test_acc))
    return train_loss, log_acc_train, log_acc_test

train_loss, train_log, test_log = run_mnist_training_loop(num_epochs, opt_state, net_type="MLP")



Epoch 1 | T: 12.47 | Train A: 0.970 | Test A: 0.963
Epoch 2 | T: 13.18 | Train A: 0.983 | Test A: 0.975
Epoch 3 | T: 12.52 | Train A: 0.990 | Test A: 0.979
Epoch 4 | T: 12.51 | Train A: 0.991 | Test A: 0.978
Epoch 5 | T: 12.73 | Train A: 0.994 | Test A: 0.980
Epoch 6 | T: 12.70 | Train A: 0.996 | Test A: 0.981
Epoch 7 | T: 12.42 | Train A: 0.997 | Test A: 0.983
Epoch 8 | T: 12.50 | Train A: 0.996 | Test A: 0.979
Epoch 9 | T: 12.87 | Train A: 0.997 | Test A: 0.979
Epoch 10 | T: 12.83 | Train A: 0.998 | Test A: 0.980


# using the stax API to build Sequential Models - case study: A CNN

In [38]:
from jax.experimental import stax
from jax.experimental.stax import (BatchNorm, Conv, Dense, Flatten, Relu, LogSoftmax)



In [54]:
init_fun, conv_net = stax.serial(Conv(32,(5,5),(2,2),padding="SAME"),
                                BatchNorm(),Relu,
                                Conv(32, (5,5), (2,2), padding="SAME"),
                                BatchNorm(),Relu,
                                Conv(10, (3,3),(2,2),padding="SAME"),
                                BatchNorm(),Relu,
                                Conv(10, (3,3),(2,2),padding="SAME"),Relu,
                                Flatten,
                                Dense(num_classes),
                                LogSoftmax)
_, params = init_fun(key, (batch_size, 1,28,28))
'''
The output returns a function to initialize the parameters of the network as well as a function to apply the forward pass through
the network with. WHen initializing we have to specify the shape of the desired input as well as the batch dimension. Similarly
as before we can then proceed to define the loss and accuracy. The only difference compared to the MLP case is that we no longer flatten the image
'''



'\nThe output returns a function to initialize the parameters of the network as well as a function to apply the forward pass through\nthe network with. WHen initializing we have to specify the shape of the desired input as well as the batch dimension. Similarly\nas before we can then proceed to define the loss and accuracy. The only difference compared to the MLP case is that we no longer flatten the image\n'

In [59]:
print(params[0][0].shape)
print(params[0][1].shape)

(5, 5, 28, 32)
(1, 1, 1, 32)


In [47]:
def accuracy(params, data_loader):
    """
    compute the accuracy for the CNN case (no flattening of input)
    """
    acc_total=0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data)
        targets = one_hot(np.array(target),num_classes)
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(conv_net(params,images),axis=1)
        acc_total += np.sum(predicted_class == target_class)
    return acc_total/len(data_loader.dataset)

def loss(params, images, targets):
    preds = conv_net(params, images)
    return -np.sum(preds*targets)

In [48]:
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)
num_epochs = 10

train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,opt_state,net_type="CNN")

Epoch 1 | T: 16.85 | Train A: 0.970 | Test A: 0.965
Epoch 2 | T: 13.81 | Train A: 0.980 | Test A: 0.976
Epoch 3 | T: 13.80 | Train A: 0.984 | Test A: 0.977
Epoch 4 | T: 13.89 | Train A: 0.987 | Test A: 0.979
Epoch 5 | T: 14.52 | Train A: 0.989 | Test A: 0.981
Epoch 6 | T: 13.85 | Train A: 0.990 | Test A: 0.981
Epoch 7 | T: 14.69 | Train A: 0.992 | Test A: 0.981
Epoch 8 | T: 14.07 | Train A: 0.993 | Test A: 0.982
Epoch 9 | T: 14.24 | Train A: 0.993 | Test A: 0.983
Epoch 10 | T: 14.16 | Train A: 0.994 | Test A: 0.981
