# Getting Started
https://flax.readthedocs.io/en/latest/getting_started.html

Build a simple CNN with the Flax Linen API and train it on the MNIST dataset.

In [1]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-7643832d-1e3a-0941-db93-db6078bd9220)


In [2]:
!nvidia-smi

Sun Sep  4 20:12:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04    Driver Version: 515.43.04    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| 30%   38C    P8    26W / 350W |     36MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## 1. Imports

In [3]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

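# Hide GPUs from TensorFlow so that TF/TFDS (used only for data loading) do not reserve GPU memory that JAX needs.
tf.config.experimental.set_visible_devices([], "GPU")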

## 2. Define network

In [4]:
class CNN(nn.Module):
    '''A simple CNN model.'''

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1)) # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

## 3. Define loss
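The inputs to `optax.softmax_cross_entropy()`, `logits` and `labels`, must both have shape `[batch, num_classes]`, so the integer labels are first converted to one-hot vectors with `jax.nn.one_hot`.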

In [5]:
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

## 4. Metric computation

In [6]:
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
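The accuracy line in `compute_metrics` compares the arg-max class with the label and averages the resulting booleans. A worked example with made-up logits for three examples:

```python
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0],
                    [0.0, 3.0],
                    [1.0, 0.0]])
labels = jnp.array([0, 0, 0])

predictions = jnp.argmax(logits, -1)   # index of the largest logit per row -> [0, 1, 0]
correct = predictions == labels        # boolean array [True, False, True]
accuracy = jnp.mean(correct)           # booleans are averaged as 0/1 -> 2/3
print(predictions, accuracy)
```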

## 5. Loading data

In [7]:
def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds

## 6. Create train state
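A common pattern in Flax is to create a single dataclass that represents the entire training state, including the step number, the parameters, and the optimizer state. Adding the optimizer and the model (its `apply_fn`) to this dataclass has the further advantage that functions such as `train_step()` need only a single argument for the whole training state. Because this pattern is so common, Flax provides the class `flax.training.train_state.TrainState`, which covers most basic use cases. To track additional data, subclass it.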

In [8]:
def create_train_state(rng, learning_rate, momentum):
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx
    )

## 7. Training step
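- Evaluate the network with `Module.apply`, passing the parameters and a batch of input images.
- Compute the `cross_entropy_loss`.
- Use `jax.value_and_grad` to evaluate both the loss function and its gradient.
- Apply the `pytree` of gradients to the optimizer to update the model parameters (`state.apply_gradients`).
- Compute the metrics with `compute_metrics`.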

In [20]:
@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics
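`jax.value_and_grad(..., has_aux=True)` expects the function to return a `(loss, aux)` pair. Only `loss` is differentiated, while `aux` (the logits in `train_step`) is passed back unchanged, which is why the result unpacks as `(_, logits), grads`. A scalar toy function makes the structure easy to see:

```python
import jax
import jax.numpy as jnp


def loss_fn(w):
    pred = 3.0 * w              # plays the role of the logits
    loss = (pred - 1.0) ** 2    # scalar that gets differentiated
    return loss, pred           # (loss, auxiliary output)


grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, pred), grad = grad_fn(jnp.float32(1.0))

# loss = (3 - 1)^2 = 4, pred = 3, d loss / d w = 2 * (3w - 1) * 3 = 12 at w = 1
print(loss, pred, grad)
```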

## 8. Evaluation step

In [10]:
@jax.jit
def eval_step(params, batch):
    logits = CNN().apply({'params': params}, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])

## 9. Train function
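- Shuffle the training data before each epoch with `jax.random.permutation`, which takes a PRNG key as its argument.
- Run one optimization step for each batch.
- Fetch the training metrics from the device with `jax.device_get` and compute their mean over all batches in the epoch.
- Return the `TrainState` with the updated parameters; the epoch's average training loss and accuracy are printed rather than returned.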

In [11]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    '''Train for a single epoch.'''
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    print(f"train epoch: {epoch}, loss: {epoch_metrics_np['loss']:.4f}, accuracy: {100 * epoch_metrics_np['accuracy']:.2f}")

    return state
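The shuffling and batching logic in `train_epoch` can be traced on a hypothetical 10-example toy dataset with a batch size of 3. Only `steps_per_epoch * batch_size = 9` shuffled indices are kept, so the leftover example is dropped for that epoch (a different one each epoch, since the permutation changes).

```python
import jax
import jax.numpy as jnp

train_ds_size = 10
batch_size = 3
steps_per_epoch = train_ds_size // batch_size        # 3 full batches; 1 example is dropped

rng = jax.random.PRNGKey(0)
perms = jax.random.permutation(rng, train_ds_size)   # shuffled indices 0..9
perms = perms[:steps_per_epoch * batch_size]         # keep 9 indices
perms = perms.reshape((steps_per_epoch, batch_size)) # one row of indices per batch

# Toy dataset in the same dict-of-arrays layout as train_ds
toy_ds = {'image': jnp.arange(train_ds_size) * 10, 'label': jnp.arange(train_ds_size)}
for perm in perms:
    batch = {k: v[perm, ...] for k, v in toy_ds.items()}  # images and labels stay aligned
    print(batch['label'], batch['image'])
```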

## 10. Eval function
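- Fetch the evaluation metrics from the device with `jax.device_get`.
- Convert each metric stored in the JAX pytree into a plain Python scalar with `jax.tree_util.tree_map` and `.item()`.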

In [12]:
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']
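The pytree utilities used here also cover the per-epoch averaging done in `train_epoch`. The sketch below uses hypothetical metric values in the same dict layout as `compute_metrics()` output; passing several trees to `jax.tree_util.tree_map` calls the function with the matching leaves of all of them, which is an alternative to the dict comprehension in `train_epoch`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical per-batch metrics, shaped like compute_metrics() output
batch_metrics = [
    {'loss': jnp.float32(0.5), 'accuracy': jnp.float32(0.75)},
    {'loss': jnp.float32(0.3), 'accuracy': jnp.float32(1.0)},
]
batch_metrics_np = jax.device_get(batch_metrics)   # device arrays -> NumPy arrays

# Average each metric over batches: tree_map receives matching leaves from every dict
epoch_metrics = jax.tree_util.tree_map(lambda *xs: np.mean(xs), *batch_metrics_np)

# Convert 0-d arrays to plain Python floats, as eval_model() does
summary = jax.tree_util.tree_map(lambda x: x.item(), epoch_metrics)
print(summary)
```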

## 11. Download data

In [13]:
train_ds, test_ds = get_datasets()

## 12. Seed randomness

In [14]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
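JAX random functions are deterministic in their key, so reusing a key reproduces the same numbers. That is why the notebook splits `rng` whenever it needs fresh randomness (`init_rng` here, `input_rng` for each epoch's shuffle) and carries the other half forward. A short demonstration:

```python
import jax

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)   # one key to consume, one to carry forward

# The same key always yields the same numbers ...
a = jax.random.normal(init_rng, (3,))
b = jax.random.normal(init_rng, (3,))

# ... so each consumer should get its own fresh key from split()
rng, input_rng = jax.random.split(rng)
c = jax.random.normal(input_rng, (3,))
print(a, b, c)
```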

## 13. Initialize train state

In [15]:
learning_rate = 0.1
momentum = 0.9

In [16]:
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng

## 14. Train and evaluate

In [17]:
num_epochs = 10
batch_size = 32

In [21]:
for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(f'test epoch: {epoch}, loss: {test_loss:.2f}, accuracy: {100 * test_accuracy:.2f}')

train epoch: 1, loss: 0.1347, accuracy: 95.89
test epoch: 1, loss: 0.12, accuracy: 96.17
train epoch: 2, loss: 0.0489, accuracy: 98.57
test epoch: 2, loss: 0.04, accuracy: 98.81
train epoch: 3, loss: 0.0347, accuracy: 98.95
test epoch: 3, loss: 0.04, accuracy: 98.80
train epoch: 4, loss: 0.0292, accuracy: 99.06
test epoch: 4, loss: 0.04, accuracy: 98.89
train epoch: 5, loss: 0.0217, accuracy: 99.34
test epoch: 5, loss: 0.04, accuracy: 98.78
train epoch: 6, loss: 0.0172, accuracy: 99.45
test epoch: 6, loss: 0.04, accuracy: 99.03
train epoch: 7, loss: 0.0165, accuracy: 99.50
test epoch: 7, loss: 0.04, accuracy: 99.03
train epoch: 8, loss: 0.0149, accuracy: 99.53
test epoch: 8, loss: 0.04, accuracy: 98.82
train epoch: 9, loss: 0.0106, accuracy: 99.67
test epoch: 9, loss: 0.04, accuracy: 99.05
train epoch: 10, loss: 0.0082, accuracy: 99.73
test epoch: 10, loss: 0.06, accuracy: 98.64


## Miscellaneous
- [About OOM (GPU memory allocation)](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html)
- [About `PRNGKey` (pseudo-random numbers)](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)
- [About pytrees](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html)
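The GPU memory allocation page linked above explains that JAX preallocates a large fraction of GPU memory when its backend starts, and that this is controlled by the environment variables `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION`. A minimal sketch, assuming the variables are set before JAX initializes (safest: before `import jax`); the fraction shown is only an example value:

```python
import os

# Must run before JAX initializes its GPU backend (safest: before `import jax`).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # allocate on demand instead of up front
# Alternatively, keep preallocation but cap it (example value; adjust as needed):
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # noqa: E402  (imported after the environment is configured)

print(jax.devices())
```

Disabling preallocation can increase memory fragmentation, so it is mainly useful when several processes share one GPU.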