# Функции для ускорения работы кода

In [1]:
import numpy as np
import jax
# Generate key which is used to generate random numbers
key = jax.random.PRNGKey(1)



Вместо обычного массива используется DeviceArray, вычисления с котороым, во-первых, без каких-либо указаний производятся на аккселераторах GPU или TPU, а во-вторых, производятся асинхронно (для этого приходится вызывать метод block_until_ready).

In [2]:
x = np.zeros(10)
x

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [3]:
y = jax.numpy.zeros(10)
y

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [6]:
x = np.random.rand(1000, 1000)
y = jax.numpy.array(x)

%timeit -n 1 -r 1 np.dot(x,x)
%timeit -n 1 -r 1 jax.numpy.dot(y,y).block_until_ready()

1 loop, best of 1: 64.2 ms per loop
1 loop, best of 1: 34.4 ms per loop


Градиент вычисляется аналитически по правилу дифференциации сложных функций (chain rule), причём справляется с циклами, условиями и рекурсией.

In [7]:
def f(x):
  return 3*x**2 + 2*x + 5

def f_prime(x):
  return 6*x + 2

print(jax.grad(f)(1.0))
print(f_prime(1.0))

8.0
8.0


JIT (just-in-time) компиляция - технология для ускорения программного кода.

In [8]:
x = np.random.rand(1000,1000)
y = jax.numpy.array(x)

def f(x):
  for _ in range(10):
      x = 0.5*x + 0.1*jax.numpy.sin(x)
  return x

g = jax.jit(f)

%timeit -n 5 -r 5 f(y).block_until_ready()
%timeit -n 5 -r 5 g(y).block_until_ready()

5 loops, best of 5: 70.4 ms per loop
5 loops, best of 5: 44.9 ms per loop


VMap (vectorizing map) - технология для быстрой организации в батчи

In [9]:
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))

def apply_matrix(v):
  return jax.numpy.dot(mat, v)

In [10]:
def naively_batched_apply_matrix(v_batched):
  return jax.numpy.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
The slowest run took 52.40 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 1.6 ms per loop


In [11]:
@jax.jit
def vmap_batched_apply_matrix(v_batched):
  return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
The slowest run took 482.28 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 33.5 µs per loop


# Многослойный перцептрон

In [12]:
from jax.experimental import optimizers

import torch
from torchvision import datasets, transforms

import time



In [13]:
batch_size = 100

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

In [14]:
def initialize_mlp(sizes, key):
    """ Initialize the weights of all layers of a linear layer network """
    keys = jax.random.split(key, len(sizes))
    # Initialize a single layer with Gaussian weights -  helper function
    def initialize_layer(m, n, key, scale=1e-2):
        w_key, b_key = jax.random.split(key)
        return scale * jax.random.normal(w_key, (n, m)), scale * jax.random.normal(b_key, (n,))
    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
# Return a list of tuples of layer weights
params = initialize_mlp(layer_sizes, key)

In [15]:
def relu_layer(params, x):
    return jax.numpy.maximum(jax.numpy.dot(params[0], x) + params[1], 0)

def forward_pass(params, in_array):
    """ Compute the forward pass for each example individually """
    activations = in_array

    # Loop over the ReLU hidden layers
    for w, b in params[:-1]:
        activations = relu_layer([w, b], activations)

    # Perform final trafo to logits
    final_w, final_b = params[-1]
    logits = jax.numpy.dot(final_w, activations) + final_b
    return logits - jax.scipy.special.logsumexp(logits)

# Make a batched version of the `predict` function
batch_forward = jax.vmap(forward_pass, in_axes=(None, 0), out_axes=0)

In [16]:
def one_hot(x, k, dtype=jax.numpy.float32):
    """Create a one-hot encoding of x of size k """
    return jax.numpy.array(x[:, None] == jax.numpy.arange(k), dtype)

def loss(params, in_arrays, targets):
    """ Compute the multi-class cross-entropy loss """
    preds = batch_forward(params, in_arrays)
    return -jax.numpy.sum(preds * targets)

def accuracy(params, data_loader):
    """ Compute the accuracy for a provided dataloader """
    acc_total = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = jax.numpy.array(data).reshape(data.size(0), 28*28)
        targets = one_hot(jax.numpy.array(target), num_classes)

        target_class = jax.numpy.argmax(targets, axis=1)
        predicted_class = jax.numpy.argmax(batch_forward(params, images), axis=1)
        acc_total += jax.numpy.sum(predicted_class == target_class)
    return acc_total/len(data_loader.dataset)

In [17]:
@jax.jit
def update(params, x, y, opt_state):
    """ Compute the gradient for a batch and update the parameters """
    value, grads = jax.value_and_grad(loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)
    return get_params(opt_state), opt_state, value

# Defining an optimizer in Jax
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

num_epochs = 5
num_classes = 10

In [18]:
def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
    """ Implements a learning loop over epochs. """
    # Initialize placeholder for loggin
    log_acc_train, log_acc_test, train_loss = [], [], []

    # Get the initial set of parameters
    params = get_params(opt_state)

    # Get initial accuracy after random init
    train_acc = accuracy(params, train_loader)
    test_acc = accuracy(params, test_loader)
    log_acc_train.append(train_acc)
    log_acc_test.append(test_acc)

    # Loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            if net_type == "MLP":
                # Flatten the image into 784 vectors for the MLP
                x = jax.numpy.array(data).reshape(data.size(0), 28*28)
            elif net_type == "CNN":
                # No flattening of the input required for the CNN
                x = jax.numpy.array(data)
            y = one_hot(jax.numpy.array(target), num_classes)
            params, opt_state, loss = update(params, x, y, opt_state)
            train_loss.append(loss)

        epoch_time = time.time() - start_time
        train_acc = accuracy(params, train_loader)
        test_acc = accuracy(params, test_loader)
        log_acc_train.append(train_acc)
        log_acc_test.append(test_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}".format(epoch+1, epoch_time,
                                                                    train_acc, test_acc))

    return train_loss, log_acc_train, log_acc_test


train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                          opt_state,
                                                          net_type="MLP")

Epoch 1 | T: 18.06 | Train A: 0.969 | Test A: 0.964
Epoch 2 | T: 17.27 | Train A: 0.984 | Test A: 0.974
Epoch 3 | T: 17.12 | Train A: 0.989 | Test A: 0.978
Epoch 4 | T: 17.13 | Train A: 0.992 | Test A: 0.980
Epoch 5 | T: 17.15 | Train A: 0.996 | Test A: 0.981


# Нейронная сеть

In [19]:
from jax.experimental import stax
from jax.experimental.stax import (BatchNorm, Conv, Dense, Flatten,
                                   Relu, LogSoftmax)



In [20]:
init_fun, conv_net = stax.serial(Conv(32, (5, 5), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(32, (5, 5), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"), Relu,
                                 Flatten,
                                 Dense(num_classes),
                                 LogSoftmax)

_, params = init_fun(key, (batch_size, 1, 28, 28))

In [21]:
def accuracy(params, data_loader):
    """ Compute the accuracy for the CNN case (no flattening of input)"""
    acc_total = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data)
        targets = one_hot(np.array(target), num_classes)

        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(conv_net(params, images), axis=1)
        acc_total += np.sum(predicted_class == target_class)
    return acc_total/len(data_loader.dataset)

def loss(params, images, targets):
    preds = conv_net(params, images)
    return -np.sum(preds * targets)

In [22]:
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)
num_epochs = 5

train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                          opt_state,
                                                          net_type="CNN")

Epoch 1 | T: 20.32 | Train A: 0.970 | Test A: 0.964
Epoch 2 | T: 15.50 | Train A: 0.981 | Test A: 0.975
Epoch 3 | T: 15.49 | Train A: 0.985 | Test A: 0.979
Epoch 4 | T: 15.46 | Train A: 0.987 | Test A: 0.978
Epoch 5 | T: 15.47 | Train A: 0.989 | Test A: 0.982
