<a href="https://colab.research.google.com/github/yblee110/jax-flax-book/blob/main/ch03_2_3_%ED%95%99%EC%8A%B5%EB%A5%A0_%EC%8A%A4%EC%BC%80%EC%A4%84%EB%A7%81_%EC%A0%81%EC%9A%A9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets
!pip install orbax-checkpoint ==0.4.4

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/547.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m153.6/547.8 kB[0m [31m4.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any

In [2]:
import jax
import flax
import optax
import datasets

import jax.numpy as jnp
from datasets import load_dataset
import flax.linen as nn

print("JAX Version : {}".format(jax.__version__)) #출력 JAX Version : 0.4.20
print("FLAX Version : {}".format(flax.__version__)) #출력 FLAX Version : 0.7.5
print("OPTAX Version : {}".format(optax.__version__))#출력 OPTAX Version : 0.1.7

JAX Version : 0.4.26
FLAX Version : 0.8.4
OPTAX Version : 0.2.2


In [3]:
def get_datasets():
  datasets = load_dataset("mnist")
  datasets = datasets.with_format("jax")
  datasets = {
    "train": {
      "image": datasets["train"]["image"][...,None].astype(jnp.float32)/255,
      "label": datasets["train"]["label"],
    },
    "test": {
      "image": datasets["test"]["image"][...,None].astype(jnp.float32)/255,
      "label": datasets["test"]["label"],
    },
  }
  return datasets['train'], datasets['test']


In [4]:
train_ds, test_ds = get_datasets()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script:   0%|          | 0.00/3.98k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.83k [00:00<?, ?B/s]

The repository for mnist contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/mnist.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/9.91M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/28.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.54k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [5]:
class CNN(nn.Module):
    num_classes: int


    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Conv(features=16, kernel_size=(5, 5), strides=(2, 2),
                    padding='VALID')(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(5, 5), strides=(2, 2),
                    padding='VALID')(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(features=self.num_classes)(x)
        return x


In [6]:
rng = jax.random.PRNGKey(0)
model = CNN(num_classes=10)
main_key, params_key, dropout_key = jax.random.split(key=rng, num=3)
variables = model.init(params_key, jnp.ones((1, 28, 28, 1)),train=False)
params = variables['params']
batch_stats = variables['batch_stats']

In [7]:
jax.tree_util.tree_map(jnp.shape, variables)

{'batch_stats': {'BatchNorm_0': {'mean': (16,), 'var': (16,)},
  'BatchNorm_1': {'mean': (32,), 'var': (32,)}},
 'params': {'BatchNorm_0': {'bias': (16,), 'scale': (16,)},
  'BatchNorm_1': {'bias': (32,), 'scale': (32,)},
  'Conv_0': {'bias': (16,), 'kernel': (5, 5, 1, 16)},
  'Conv_1': {'bias': (32,), 'kernel': (5, 5, 16, 32)},
  'Dense_0': {'bias': (10,), 'kernel': (32, 10)}}}

In [8]:
def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
    'loss': loss,
    'accuracy': accuracy
  }
  return metrics

In [9]:
def create_learning_rate_fn(num_epochs, warmup_epochs, base_learning_rate, steps_per_epoch):
    warmup_fn = optax.linear_schedule(
        init_value=0., end_value=base_learning_rate,
        transition_steps=warmup_epochs * steps_per_epoch)
    cosine_epochs = max(num_epochs - warmup_epochs, 1)

    cosine_fn = optax.cosine_decay_schedule(
        init_value=base_learning_rate,
        decay_steps=cosine_epochs * steps_per_epoch)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, cosine_fn],
        boundaries=[warmup_epochs * steps_per_epoch])

    return schedule_fn

In [10]:
from flax.training import train_state
from typing import Any

train_epoch = 10
warmup_epoch = 2
learning_rate = 0.01
batch_size = 64
eval_batch_size = 100
train_ds_size = len(train_ds['image'])
steps_per_epoch = train_ds_size // batch_size
learning_rate_fn = create_learning_rate_fn(train_epoch, warmup_epoch, learning_rate, steps_per_epoch)

class TrainState(train_state.TrainState):
  batch_stats: Any
  key: jax.random.key

tx = optax.adam(learning_rate=learning_rate_fn)
state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  batch_stats=batch_stats,
  key=dropout_key,
  tx=tx,
)

In [11]:
import orbax.checkpoint
from flax.training import orbax_utils
import os
import shutil

In [12]:
train_path = os.getenv('HOME') + '/cnn/'


ckpt_mgr = orbax.checkpoint.CheckpointManager(
       train_path,
       orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()),
       orbax.checkpoint.CheckpointManagerOptions(
           create=not (os.path.isdir(train_path)),
           max_to_keep=3,
           step_prefix="model_epoch",
       ),
   )


In [13]:
import functools
@functools.partial(jax.jit, static_argnums=2)
def train_step(state, batch, learning_rate_fn):
  dropout_train_key = jax.random.fold_in(key=dropout_key,
    data=state.step)
  def loss_fn(params):
    logits, updates = state.apply_fn(
    {'params': params, 'batch_stats': state.batch_stats},
    batch['image'], train=True, mutable=['batch_stats'],
    rngs={'dropout': dropout_train_key})
    loss = jnp.mean(optax.softmax_cross_entropy(
      logits=logits,
      labels=jax.nn.one_hot(batch['label'],num_classes=10)))
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = compute_metrics(logits, batch['label'])
  lr = learning_rate_fn(state.step)
  metrics['learning_rate'] = lr
  return state, metrics


In [14]:
@jax.jit
def eval_step(state, batch):
  logits = state.apply_fn(
     {'params': state.params, 'batch_stats': state.batch_stats},
     batch['image'], train=False)
  return compute_metrics(logits, batch['label'])

In [15]:
def train_loop(state, train_ds, batch_size, epoch, learning_rate_fn, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch, learning_rate_fn)
    batch_metrics.append(metrics)
  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {k: sum([metrics[k] for metrics in
    training_batch_metrics])/steps_per_epoch
    for k in training_batch_metrics[0]}
  print('EPOCH: %d\nTraining loss: %.4f, accuracy: %.2f' % (epoch,
    training_epoch_metrics['loss'],
    training_epoch_metrics['accuracy'] * 100))
  return state


In [16]:
def eval_loop(state, test_ds, batch_size):
  eval_ds_size = test_ds['image'].shape[0]
  steps_per_epoch = eval_ds_size // batch_size
  batch_metrics = []
  for i in range(steps_per_epoch):
    batch = {k: v[i*batch_size:(i+1)*batch_size, ...] for k, v in test_ds.items()}
    metrics = eval_step(state, batch)
    batch_metrics.append(metrics)
  eval_batch_metrics = jax.device_get(batch_metrics)
  eval_batch_metrics = {
    k: sum([metrics[k] for metrics in eval_batch_metrics])/steps_per_epoch
    for k in eval_batch_metrics[0]}

  print('    Eval loss: %.4f, accuracy: %.2f' % (eval_batch_metrics['loss'], eval_batch_metrics['accuracy'] * 100))

In [17]:
for epoch in range(train_epoch):
  rng, key = jax.random.split(rng)
  state = train_loop(state, train_ds, batch_size, epoch,
    learning_rate_fn, rng)
  eval_loop(state, test_ds, eval_batch_size)
  save_args = orbax_utils.save_args_from_target(state)
  ckpt_mgr.save(epoch, state, save_kwargs={"save_args": save_args})

EPOCH: 0
Training loss: 0.9390, accuracy: 72.74
    Eval loss: 0.3019, accuracy: 90.66
EPOCH: 1
Training loss: 0.2484, accuracy: 92.45
    Eval loss: 0.1795, accuracy: 94.44
EPOCH: 2
Training loss: 0.1893, accuracy: 94.27
    Eval loss: 0.1090, accuracy: 96.86
EPOCH: 3
Training loss: 0.1559, accuracy: 95.26
    Eval loss: 0.0948, accuracy: 97.32
EPOCH: 4
Training loss: 0.1393, accuracy: 95.73
    Eval loss: 0.1138, accuracy: 96.61
EPOCH: 5
Training loss: 0.1211, accuracy: 96.20
    Eval loss: 0.0641, accuracy: 98.09
EPOCH: 6
Training loss: 0.1069, accuracy: 96.67
    Eval loss: 0.0626, accuracy: 98.14
EPOCH: 7
Training loss: 0.0972, accuracy: 96.98
    Eval loss: 0.0504, accuracy: 98.49
EPOCH: 8
Training loss: 0.0882, accuracy: 97.24
    Eval loss: 0.0479, accuracy: 98.55
EPOCH: 9
Training loss: 0.0841, accuracy: 97.41
    Eval loss: 0.0475, accuracy: 98.56


In [18]:
raw_restored = ckpt_mgr.restore('9')
raw_restored

{'batch_stats': {'BatchNorm_0': {'mean': array([ 0.59678304,  0.34670404, -0.32364425, -1.5403948 , -0.41289315,
          -0.9029092 , -1.0194184 , -1.4013318 , -0.2755111 , -0.40960672,
          -0.2929568 , -1.3420625 ,  0.9907112 , -1.3641517 , -0.05670249,
          -0.9666436 ], dtype=float32),
   'var': array([2.754574  , 0.4765729 , 0.79068047, 3.139733  , 1.9464167 ,
          3.3058364 , 3.4047148 , 3.8393342 , 1.8070192 , 2.9429402 ,
          0.3669262 , 3.5045764 , 2.1112766 , 2.0194216 , 0.6417175 ,
          2.8159597 ], dtype=float32)},
  'BatchNorm_1': {'mean': array([-3.7710893 ,  1.0635827 , -5.265705  , -0.9515563 ,  5.7551737 ,
          -7.405269  , -0.0529795 , -0.88988537, -2.9906244 ,  5.9144897 ,
          -1.8900927 , -3.976402  , -2.58967   , -6.522824  , -2.9571164 ,
           1.7212024 ,  0.36768922,  0.04797452, -0.2436241 , -7.0267615 ,
           5.0057926 , -4.0950956 ,  1.439795  , -5.6158085 , -4.036963  ,
           4.252638  , -0.54707086, -4.911

In [19]:
step = ckpt_mgr.latest_step()  # step = 4
ckpt_mgr.restore(step)

{'batch_stats': {'BatchNorm_0': {'mean': array([ 0.59678304,  0.34670404, -0.32364425, -1.5403948 , -0.41289315,
          -0.9029092 , -1.0194184 , -1.4013318 , -0.2755111 , -0.40960672,
          -0.2929568 , -1.3420625 ,  0.9907112 , -1.3641517 , -0.05670249,
          -0.9666436 ], dtype=float32),
   'var': array([2.754574  , 0.4765729 , 0.79068047, 3.139733  , 1.9464167 ,
          3.3058364 , 3.4047148 , 3.8393342 , 1.8070192 , 2.9429402 ,
          0.3669262 , 3.5045764 , 2.1112766 , 2.0194216 , 0.6417175 ,
          2.8159597 ], dtype=float32)},
  'BatchNorm_1': {'mean': array([-3.7710893 ,  1.0635827 , -5.265705  , -0.9515563 ,  5.7551737 ,
          -7.405269  , -0.0529795 , -0.88988537, -2.9906244 ,  5.9144897 ,
          -1.8900927 , -3.976402  , -2.58967   , -6.522824  , -2.9571164 ,
           1.7212024 ,  0.36768922,  0.04797452, -0.2436241 , -7.0267615 ,
           5.0057926 , -4.0950956 ,  1.439795  , -5.6158085 , -4.036963  ,
           4.252638  , -0.54707086, -4.911