# Deep Learning Tutorial

Main references:

- [Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html)
- [From PyTorch to JAX: towards neural net frameworks that purify stateful code](https://sjmielke.com/jax-purify.htm)
- [JAX vs Tensorflow vs Pytorch: Building a Variational Autoencoder (VAE)](https://theaisummer.com/jax-tensorflow-pytorch/)

We start by training a simple neural network.

## Training a Simple Neural Network

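First, we use JAX to specify and train a simple MLP on MNIST. We load the images and labels with PyTorch's data loading API (because it is very good, and we do not need yet another data loading library).

Of course, JAX can be used with any NumPy-compatible API to make the model more plug-and-play. In this subsection, purely for illustration, we build the model without any neural network library or special API.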

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

First, define the hyperparameters and initialize the network parameters.

In [2]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
param_scale = 0.1  # note: never used below; layers are initialized with the default scale=1e-2
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))



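Next, define the prediction function. Note that it is defined for a single image example. We will use JAX's `vmap` to handle mini-batches automatically, with no performance penalty.

For background on the `logsumexp` used here, see [this post](https://blog.feedly.com/tricks-of-the-trade-logsumexp/).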
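
To make the role of `logsumexp` concrete, here is a minimal sketch of why `predict` returns `logits - logsumexp(logits)` rather than computing the log-softmax naively. The logit values are made up for illustration and are deliberately large so that the naive version overflows.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Illustrative logits, chosen large enough that exp() overflows.
logits = jnp.array([1000.0, 1001.0, 1002.0])

# Naive log-softmax: exp(1000) overflows to inf, so the log of the sum
# is inf and every entry of the result becomes -inf.
naive = logits - jnp.log(jnp.sum(jnp.exp(logits)))

# Stable log-softmax: logsumexp subtracts the maximum before exponentiating.
stable = logits - logsumexp(logits)

print(naive)
print(stable)
# Exponentiating the stable log-probabilities gives values that sum to ~1.
print(jnp.exp(stable).sum())
```

The stable result is a vector of log-probabilities, which is what the loss function later multiplies by the one-hot targets.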

In [3]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    # per-example predictions
    activations = image
    # all layers before the last one
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    # the last layer
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

Let's check that the prediction function works on a single image.

In [4]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [5]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')

Invalid shapes!


In [6]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


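We now have all the ingredients needed to define and train a neural network. We have built an auto-batched version of `predict`, which we can use in the loss function. We can then use `grad` to compute the derivative of the loss with respect to the network parameters, and `jit` to speed everything up.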
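
As a rough illustration of how the three transformations compose, the sketch below applies them to a toy squared-error function instead of the MLP. The function `squared_error` and the data are hypothetical and exist only for this example.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def squared_error(w, x, y):
    # Loss for a single example: (w . x - y)^2
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 1.0, 1.0])

# grad differentiates with respect to the first argument (w) by default.
grad_single = grad(squared_error)

# vmap maps over the leading axis of xs and ys while sharing w (in_axes=None);
# jit compiles the batched gradient function.
per_example_grads = jit(vmap(grad_single, in_axes=(None, 0, 0)))

g = per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2)
```

The `in_axes=(None, 0, 0)` pattern mirrors `batched_predict`: the parameters are shared across the batch, and only the data arguments are mapped.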

Next, define the utility and loss functions.

In [7]:
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

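Load the data with PyTorch. JAX focuses on program transformations and accelerator-backed NumPy, so the JAX library does not include data loading or munging. There are already many excellent data loaders, so we use them rather than reinventing one. Here we use PyTorch's data loader with a tiny shim that makes it work with NumPy arrays.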

In [8]:
import numpy as np
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class NumpyLoader(data.DataLoader):
    def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
        # Plain super() avoids the infinite recursion that
        # super(self.__class__, self) causes if NumpyLoader is subclassed.
        super().__init__(dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=numpy_collate,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))

In [9]:
# Define our dataset, using torch datasets
# Flatten each image into a 1-D vector
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)


In [10]:
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)



Training loop

In [11]:
import time

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        params = update(params, x, y)
    epoch_time = time.time() - start_time
    
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 5.54 sec
Training set accuracy 0.9158166646957397
Test set accuracy 0.919700026512146
Epoch 1 in 4.34 sec
Training set accuracy 0.9371833205223083
Test set accuracy 0.9384999871253967
Epoch 2 in 4.59 sec
Training set accuracy 0.9492166638374329
Test set accuracy 0.9470000267028809
Epoch 3 in 4.67 sec
Training set accuracy 0.9567166566848755
Test set accuracy 0.9532999992370605
Epoch 4 in 4.92 sec
Training set accuracy 0.9630500078201294
Test set accuracy 0.957099974155426
Epoch 5 in 4.40 sec
Training set accuracy 0.9674500226974487
Test set accuracy 0.9617999792098999
Epoch 6 in 4.50 sec
Training set accuracy 0.97079998254776
Test set accuracy 0.9652000069618225
Epoch 7 in 4.47 sec
Training set accuracy 0.9737833142280579
Test set accuracy 0.9672999978065491


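We have now used the core of the JAX API: `grad` for derivatives, `jit` for speedups, and `vmap` for auto-vectorization. We used NumPy to specify all of the computation, borrowed the excellent data loaders from PyTorch, and ran the whole program.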

## From PyTorch to JAX

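Before moving on to JAX-based neural network libraries, it is worth understanding how JAX differs from PyTorch.

JAX is interesting mainly for its high-performance NumPy and its automatic differentiation, that is, computing the gradient of some loss function with respect to its input parameters. But moving from PyTorch or TensorFlow 2 to JAX is no small change: the basic way of building up a computation and, more importantly, of backpropagating through it is fundamentally different. PyTorch builds a computation graph as the forward pass runs; a single call to `backward()` on some "result" node then augments each intermediate node in the graph with the gradient of the result node with respect to that intermediate node. JAX instead has you express the computation as a Python function, and transforming it with `grad()` gives a gradient function that you can evaluate like the computation function itself, except that it returns the gradient of the output with respect to (by default) the first argument of the function: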

![](pictures/comparison_small.png)
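
The "first argument by default" behavior can be sketched in a few lines. The function `loss` and its inputs are hypothetical; `argnums` is the parameter of `jax.grad` that selects which positional arguments to differentiate with respect to.

```python
import jax.numpy as jnp
from jax import grad

def loss(w, b, x):
    # Toy scalar-valued function of two parameters and one input.
    return jnp.sum((w * x + b) ** 2)

w, b, x = 2.0, 1.0, jnp.array([1.0, 2.0])

# Default: gradient with respect to the first argument, w.
dloss_dw = grad(loss)
# argnums=1 selects the second argument, b.
dloss_db = grad(loss, argnums=1)
# A tuple of indices returns a tuple of gradients.
dloss_dw_db = grad(loss, argnums=(0, 1))

print(dloss_dw(w, b, x))     # 26.0
print(dloss_db(w, b, x))     # 16.0
print(dloss_dw_db(w, b, x))
```

Unlike `backward()` in PyTorch, nothing is stored on `w` or `b`: the gradient is simply the return value of a new function.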

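Naturally, this affects how you write code and build models in the two frameworks. So if you are used to automatic differentiation with stateful objects in PyTorch or TensorFlow 2, moving to JAX may take some getting used to.

If you look at libraries such as flax, trax, or haiku, their neural network examples do not look all that different from those of any other framework: define some layers, run some trainers... But what actually happens underneath? What is the route from small NumPy functions to training large hierarchical neural networks? It is worth understanding.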

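In this subsection we will:

1. quickly recap a stateful LSTM-LM implementation in PyTorch, a framework based on reverse-mode automatic differentiation,
2. see how PyTorch-style coding relies on mutable state, learn about immutable pure functions, and build (pure) zappy one-liners in JAX,
3. register individual parameters as pytree nodes and scale step by step from individual parameters to medium-sized modules,
4. combat growing pains by building fancy scaffolding and controlling context to extract initialized parameters and purify functions, and
5. realize that we could have gotten all of this easily from DeepMind's haiku framework using its `transform` mechanism.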
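
Steps 2 and 3 in the list above rest on two ideas: pure functions that return new values instead of mutating state, and pytrees that let JAX treat nested containers of arrays as a unit. The following is a minimal sketch of both, assuming a hypothetical nested-dict parameter layout (`params["dense"]["w"]`, `params["dense"]["b"]`) that does not come from the referenced article.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: a nested dict of arrays is already a pytree.
params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(2)}}

def loss(params, x):
    # Pure function: the output depends only on the arguments.
    out = params["dense"]["w"] @ x + params["dense"]["b"]
    return jnp.sum(out ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# The gradient has the same nested structure as params.
grads = jax.grad(loss)(params, x)

# Pure SGD step: tree_map builds a new pytree and leaves params untouched.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

print(jax.tree_util.tree_map(jnp.shape, grads))
print(new_params["dense"]["b"])
```

Compared with the list comprehension in the earlier `update` function, `tree_map` does not need to know how the parameters are nested, which is what makes it possible to scale from single parameters to whole modules.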

### An LSTM-LM in PyTorch

Implement an LSTMCell:

In [None]:
import torch

class LSTMCell(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super(LSTMCell, self).__init__()
        self.weight_ih = torch.nn.Parameter(torch.rand(4*out_dim, in_dim))
        self.weight_hh = torch.nn.Parameter(torch.rand(4*out_dim, out_dim))
        self.bias = torch.nn.Parameter(torch.zeros(4*out_dim,))
        
    def forward(self, inputs, h, c):
        ifgo = self.weight_ih @ inputs + self.weight_hh @ h + self.bias
        i, f, g, o = torch.chunk(ifgo, 4)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        g = torch.tanh(g)
        o = torch.sigmoid(o)
        new_c = f * c + i * g
        new_h = o * torch.tanh(new_c)
        return (new_h, new_c)
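
For comparison, one possible pure-function counterpart of this cell in JAX is sketched below. This is an assumption-laden illustration rather than the referenced article's code: the names `init_lstm_params` and `lstm_cell` and the dict-of-arrays parameter layout are hypothetical. The weights are passed in explicitly instead of living on `self`.

```python
import jax
import jax.numpy as jnp

def init_lstm_params(key, in_dim, out_dim):
    # Mirrors the PyTorch initialization: uniform [0, 1) weights, zero bias.
    k_ih, k_hh = jax.random.split(key)
    return {
        "weight_ih": jax.random.uniform(k_ih, (4 * out_dim, in_dim)),
        "weight_hh": jax.random.uniform(k_hh, (4 * out_dim, out_dim)),
        "bias": jnp.zeros((4 * out_dim,)),
    }

def lstm_cell(params, inputs, h, c):
    # Pure function: all state (params, h, c) comes in as arguments
    # and the new state is returned.
    ifgo = params["weight_ih"] @ inputs + params["weight_hh"] @ h + params["bias"]
    i, f, g, o = jnp.split(ifgo, 4)
    i = jax.nn.sigmoid(i)
    f = jax.nn.sigmoid(f)
    g = jnp.tanh(g)
    o = jax.nn.sigmoid(o)
    new_c = f * c + i * g
    new_h = o * jnp.tanh(new_c)
    return new_h, new_c

params = init_lstm_params(jax.random.PRNGKey(0), in_dim=3, out_dim=2)
h0, c0 = jnp.zeros(2), jnp.zeros(2)
h1, c1 = lstm_cell(params, jnp.ones(3), h0, c0)
print(h1.shape, c1.shape)  # (2,) (2,)
```

Because `lstm_cell` takes its parameters as the first argument, `jax.grad` can differentiate a loss built on it with respect to the whole `params` dict.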

## Deep Learning Libraries for JAX

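According to https://github.com/google/jax#neural-network-libraries, several Google research groups have developed neural network libraries with JAX, such as [Flax](https://github.com/google/flax), [Trax](https://github.com/google/trax), and [Objax](https://github.com/google/objax). DeepMind also has [a JAX-based open-source ecosystem](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research), including [Haiku](https://github.com/deepmind/dm-haiku), [Optax](https://github.com/deepmind/optax), [RLax](https://github.com/deepmind/rlax), and [Chex](https://github.com/deepmind/chex).

If you want a fully featured library for neural network training with examples and how-to guides, try Flax. Another option is Trax, a combinator-based framework focused on ease of use and end-to-end single-command examples, especially for sequence models and reinforcement learning. Objax is a minimalist object-oriented framework with a PyTorch-like interface.

Haiku provides neural network modules; Optax handles gradient processing and optimization; RLax provides building blocks for RL algorithms; Chex is for reliable code and testing.

In my view, Flax is the easiest one to get started with.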

