<a href="https://colab.research.google.com/github/Dawntown/numerical-analysis-learning/blob/master/pmlintro_ch1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torchvision
from torchvision import datasets
from torchvision import transforms
import numpy as np
import jax
import jax.numpy as jnp
import itertools
from bokeh.io import output_notebook, show
from bokeh.layouts import gridplot
from bokeh.plotting import figure

In [6]:
output_notebook()

## show EMNIST

In [7]:
transform = transforms.Compose([
    lambda img: torchvision.transforms.functional.rotate(img, 90),
    transforms.ToTensor(),
    jnp.array
])

training_data = datasets.EMNIST(root='~/torchdata', split='byclass', download=True, transform=transform)

In [6]:
def plot_one(item):
    image, raw_label = item
    label = training_data.classes[raw_label]
    p = figure(
        title=f"label = {label}", 
        tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
        match_aspect = True
    )
    p.x_range.range_padding = p.y_range.range_padding = 0
    subplot = p.image(image=[np.array(image.squeeze())], x=0, y=0, dw=1, dh=1, level="image")
    
    p.title.align = "center"
    p.axis.visible = False
    p.grid.grid_line_width = 0.5
    
    return p

In [7]:
subplots = list(map(plot_one, itertools.islice(training_data, 25)))
grid = gridplot(subplots, ncols=5, toolbar_location=None, width=150, height=150)
show(grid)



## MNIST demo

In [8]:
HEIGHT = 28 
WIDTH = 28 
CHANNELS = 1 
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS 
NUM_LABELS = 10

In [9]:
transform = transforms.Compose([
    transforms.RandomVerticalFlip(p=1.0),
    transforms.ToTensor(),
    lambda img: img / 255,
    # jnp.array
])
data_train = datasets.MNIST(root="~/torchdata", train=True, download=True, transform=transform)
data_test = datasets.MNIST(root="~/torchdata", train=False, download=True, transform=transform)

In [10]:
def plot_one(item):
    image, label = item
    p = figure(
        title=f"label = {label}", 
        tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
        match_aspect = True
    )
    p.x_range.range_padding = p.y_range.range_padding = 0
    subplot = p.image(image=[np.array(image.squeeze())], x=0, y=0, dw=1, dh=1, level="image")
    
    p.title.align = "center"
    p.axis.visible = False
    p.grid.grid_line_width = 0.5
    
    return p

In [11]:
subplots = list(map(plot_one, itertools.islice(data_train, 25)))
grid = gridplot(subplots, ncols=5, toolbar_location=None, width=150, height=150)
show(grid)

## Neural Network Initialization

In [12]:
from jax import random
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

def init_nn_params(sizes, key=random.PRNGKey(0), scale=1e-2):
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]



$\boldsymbol{y} = \mathbf{W}\boldsymbol{x} + \boldsymbol{b}: \mathbb{R}^{m} \to \mathbb{R}^n$,
where $\boldsymbol{y}\in\mathbb{R}^{n}$, $\mathbf{W}\in\mathbb{R}^{n\times m}$, $\boldsymbol{x}\in\mathbb{R}^{m}$, and $\boldsymbol{b}\in\mathbb{R}^{n}$. 

In [13]:
params = init_nn_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [14]:
import jax.numpy as jnp
from jax.nn import swish

def predict(params, image):
    activations  = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = swish(outputs)
        
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

$\text{swish}(x) = x*\text{sigmoid}(\beta x)$

In [15]:
random_flattened_image = random.normal(random.PRNGKey(1), (28*28*1,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [28]:
from jax import vmap
batched_predict = vmap(predict, in_axes=(None, 0))

In [29]:
random_flattened_images = random.normal(random.PRNGKey(1), (32, 28*28*1))
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(32, 10)


In [30]:
from jax.nn import logsumexp

def loss(params, images, targets):
    logits = batched_predict(params, images)
    log_preds = logits - logsumexp(logits)
    return - jnp.mean(targets * log_preds)

In [31]:
from jax import  grad, value_and_grad

INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

In [32]:

def update(params, x, y, epoch_number):
    grads = grad(loss)(params, x, y)
    lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
    return [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip(params, grads)]

def update2(params, x, y, epoch_number):
    loss_value, grads = value_and_grad(loss)(params, x, y)
    lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
    return [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip(params, grads)], loss_value

In [33]:
from jax.nn import one_hot
from torch.utils.data import DataLoader
import time

train_loader = DataLoader(data_train, batch_size=64, num_workers=16)
test_loader = DataLoader(data_test, batch_size=64, num_workers=16)

num_epochs = 100



In [34]:
def batch_accuracy(params, images, targets):
    images = jnp.reshape(images, (len(images), NUM_PIXELS))
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == targets)

def accuracy(params, data):
    accs = []
    for images, targets in data:
        images = jnp.array(images.numpy())
        targets = jnp.array(targets.numpy())
        accs.append(batch_accuracy(params, images, targets))
    return jnp.mean(jnp.array(accs))
    
import time

In [68]:
for epoch in range(num_epochs):
    start_time = time.time()
    losses = []
    for x, y in train_loader:
        x = jnp.reshape(jnp.array(x.numpy()), (len(x), NUM_PIXELS))
        y = one_hot(jnp.array(y.numpy()), NUM_LABELS)
        params, loss_value = update2(params, x, y, epoch)
        losses.append(loss_value)
        
    epoch_time = time.time() - start_time
    
    start_time = time.time()
    train_acc = accuracy(params, train_loader)
    test_acc = accuracy(params, test_loader)
    eval_time = time.time() - start_time
    
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) 
    print("Eval in {:0.2f} sec".format(eval_time)) 
    print("Training set loss {}".format(jnp.mean(jnp.array(losses)))) 
    print("Training set accuracy {}".format(train_acc)) 
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 8.77 sec
Eval in 3.64 sec
Training set loss 0.6450716853141785
Training set accuracy 0.20262527465820312
Test set accuracy 0.203125
Epoch 1 in 8.72 sec
Eval in 3.61 sec
Training set loss 0.6448829174041748
Training set accuracy 0.20903851091861725
Test set accuracy 0.20879776775836945
Epoch 2 in 8.75 sec
Eval in 3.63 sec
Training set loss 0.6446545720100403
Training set accuracy 0.21283648908138275
Test set accuracy 0.21228104829788208
Epoch 3 in 8.80 sec
Eval in 3.66 sec
Training set loss 0.6443791389465332
Training set accuracy 0.2238806039094925
Test set accuracy 0.224920392036438
Epoch 4 in 9.00 sec
Eval in 3.65 sec
Training set loss 0.6440461874008179
Training set accuracy 0.24262060225009918
Test set accuracy 0.2427348792552948
Epoch 5 in 8.86 sec
Eval in 3.64 sec
Training set loss 0.6436436772346497
Training set accuracy 0.2632429301738739
Test set accuracy 0.26214173436164856
Epoch 6 in 8.77 sec
Eval in 3.62 sec
Training set loss 0.6431568264961243
Training set accur

## Add JIT Compilation

In [37]:
from jax import jit

@jit
def update2(params, x, y, epoch_number):
    loss_value, grads = value_and_grad(loss)(params, x, y)
    lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
    return [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip (params, grads)], loss_value


@jit
def batch_accuracy(params, images, targets):
    images = jnp.reshape(images, (len(images), NUM_PIXELS))
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == targets)

In [38]:
for epoch in range(num_epochs):
    start_time = time.time()
    losses = []
    for x, y in train_loader:
        x = jnp.reshape(jnp.array(x.numpy()), (len(x), NUM_PIXELS))
        y = one_hot(jnp.array(y.numpy()), NUM_LABELS)
        params, loss_value = update2(params, x, y, epoch)
        losses.append(loss_value)
        
    epoch_time = time.time() - start_time
    
    start_time = time.time()
    train_acc = accuracy(params, train_loader)
    test_acc = accuracy(params, test_loader)
    eval_time = time.time() - start_time
    
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) 
    print("Eval in {:0.2f} sec".format(eval_time)) 
    print("Training set loss {}".format(jnp.mean(jnp.array(losses)))) 
    print("Training set accuracy {}".format(train_acc)) 
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 4.66 sec
Eval in 7.03 sec
Training set loss 0.6459628939628601
Training set accuracy 0.1043943241238594
Test set accuracy 0.10260748863220215
Epoch 1 in 4.36 sec
Eval in 6.45 sec
Training set loss 0.6459582448005676
Training set accuracy 0.1043943241238594
Test set accuracy 0.10260748863220215
Epoch 2 in 4.48 sec
Eval in 6.56 sec
Training set loss 0.6459534764289856
Training set accuracy 0.1043943241238594
Test set accuracy 0.10260748863220215
Epoch 3 in 4.29 sec
Eval in 6.62 sec
Training set loss 0.6459482908248901
Training set accuracy 0.1043943241238594
Test set accuracy 0.10260748863220215
Epoch 4 in 4.49 sec
Eval in 6.72 sec
Training set loss 0.6459426879882812
Training set accuracy 0.1043943241238594
Test set accuracy 0.10260748863220215
Epoch 5 in 4.35 sec
Eval in 6.79 sec
Training set loss 0.6459364891052246
Training set accuracy 0.1043943241238594
Test set accuracy 0.10260748863220215
Epoch 6 in 4.41 sec
Eval in 7.01 sec
Training set loss 0.6459295153617859
Training