<a href="https://colab.research.google.com/github/yblee110/jax-flax-book/blob/main/ch03_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jax==0.4.24
!pip install flax==0.7.5
!pip install optax==0.1.7
!pip install datasets
!pip install orbax-checkpoint ==0.4.4

[31mERROR: Invalid requirement: '==0.4.4'[0m[31m
[0m

In [2]:
import jax
import flax
import optax
import datasets

import jax.numpy as jnp
from datasets import load_dataset
import flax.linen as nn

print("JAX Version : {}".format(jax.__version__)) #출력 JAX Version : 0.4.20
print("FLAX Version : {}".format(flax.__version__)) #출력 FLAX Version : 0.7.5
print("OPTAX Version : {}".format(optax.__version__))#출력 OPTAX Version : 0.1.7

JAX Version : 0.4.24
FLAX Version : 0.7.5
OPTAX Version : 0.1.7


In [3]:
def get_datasets():
  datasets = load_dataset("mnist")
  datasets = datasets.with_format("jax")
  datasets = {
    "train": {
      "image": datasets["train"]["image"][...,None].astype(jnp.float32)/255,
      "label": datasets["train"]["label"],
    },
    "test": {
      "image": datasets["test"]["image"][...,None].astype(jnp.float32)/255,
      "label": datasets["test"]["label"],
    },
  }
  return datasets['train'], datasets['test']


In [4]:
train_ds, test_ds = get_datasets()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [5]:
class CNN(nn.Module):
    num_classes: int


    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Conv(features=16, kernel_size=(5, 5), strides=(2, 2),
                    padding='VALID')(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(5, 5), strides=(2, 2),
                    padding='VALID')(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(features=self.num_classes)(x)
        return x


In [6]:
rng = jax.random.PRNGKey(0)
model = CNN(num_classes=10)
main_key, params_key, dropout_key = jax.random.split(key=rng, num=3)
variables = model.init(params_key, jnp.ones((1, 28, 28, 1)),train=False)
params = variables['params']
batch_stats = variables['batch_stats']

In [7]:
jax.tree_util.tree_map(jnp.shape, variables)

{'batch_stats': {'BatchNorm_0': {'mean': (16,), 'var': (16,)},
  'BatchNorm_1': {'mean': (32,), 'var': (32,)}},
 'params': {'BatchNorm_0': {'bias': (16,), 'scale': (16,)},
  'BatchNorm_1': {'bias': (32,), 'scale': (32,)},
  'Conv_0': {'bias': (16,), 'kernel': (5, 5, 1, 16)},
  'Conv_1': {'bias': (32,), 'kernel': (5, 5, 16, 32)},
  'Dense_0': {'bias': (10,), 'kernel': (32, 10)}}}

In [8]:
def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
    'loss': loss,
    'accuracy': accuracy
  }
  return metrics

In [9]:
def create_learning_rate_fn(num_epochs, warmup_epochs, base_learning_rate, steps_per_epoch):
    warmup_fn = optax.linear_schedule(
        init_value=0., end_value=base_learning_rate,
        transition_steps=warmup_epochs * steps_per_epoch)
    cosine_epochs = max(num_epochs - warmup_epochs, 1)

    cosine_fn = optax.cosine_decay_schedule(
        init_value=base_learning_rate,
        decay_steps=cosine_epochs * steps_per_epoch)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, cosine_fn],
        boundaries=[warmup_epochs * steps_per_epoch])

    return schedule_fn

In [10]:
from flax.training import train_state
from typing import Any

train_epoch = 10
warmup_epoch = 2
learning_rate = 0.01
batch_size = 64
eval_batch_size = 100
train_ds_size = len(train_ds['image'])
steps_per_epoch = train_ds_size // batch_size
learning_rate_fn = create_learning_rate_fn(train_epoch, warmup_epoch, learning_rate, steps_per_epoch)

class TrainState(train_state.TrainState):
  batch_stats: Any
  key: jax.random.key

tx = optax.adam(learning_rate=learning_rate_fn)
state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  batch_stats=batch_stats,
  key=dropout_key,
  tx=tx,
)

In [25]:
import numpy as np

config = {'dimenstions' : np.array([24,24])}

ckpt = {'model': state, 'config': config}
ckpt

{'model': TrainState(step=0, apply_fn=<bound method Module.apply of CNN(
     # attributes
     num_classes = 10
 )>, params={'Conv_0': {'kernel': Array([[[[-0.12646194, -0.18698049,  0.02189859, -0.07726611,
            0.05319414,  0.342181  , -0.11351699,  0.00270566,
           -0.4247318 , -0.28328732, -0.10045266,  0.06585782,
           -0.02638764, -0.04056114, -0.12743749,  0.04699072]],
 
         [[ 0.30871987,  0.21376936, -0.09365036,  0.23345868,
            0.40069953, -0.09183456,  0.17887329, -0.13913761,
           -0.10279006,  0.255104  ,  0.23070394, -0.24359593,
           -0.1385247 ,  0.1087406 , -0.15572512, -0.27195057]],
 
         [[-0.16603094, -0.00411084, -0.08336388, -0.27989522,
           -0.32205588,  0.03834669,  0.17229337,  0.06695329,
            0.12174994, -0.21395265,  0.04941479, -0.17655598,
           -0.13041228,  0.14836961, -0.26886654, -0.1559371 ]],
 
         [[-0.13729113, -0.05941404,  0.20874038,  0.10283542,
            0.24327706,

In [26]:
import orbax.checkpoint
from flax.training import orbax_utils
import os
import shutil

In [27]:
ckpt_dir = '/tmp/flax_ckpt'
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

In [29]:
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)

In [30]:
import functools
@functools.partial(jax.jit, static_argnums=2)
def train_step(state, batch, learning_rate_fn):
  dropout_train_key = jax.random.fold_in(key=dropout_key,
    data=state.step)
  def loss_fn(params):
    logits, updates = state.apply_fn(
    {'params': params, 'batch_stats': state.batch_stats},
    batch['image'], train=True, mutable=['batch_stats'],
    rngs={'dropout': dropout_train_key})
    loss = jnp.mean(optax.softmax_cross_entropy(
      logits=logits,
      labels=jax.nn.one_hot(batch['label'],num_classes=10)))
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = compute_metrics(logits, batch['label'])
  lr = learning_rate_fn(state.step)
  metrics['learning_rate'] = lr
  return state, metrics


In [31]:
@jax.jit
def eval_step(state, batch):
  logits = state.apply_fn(
     {'params': state.params, 'batch_stats': state.batch_stats},
     batch['image'], train=False)
  return compute_metrics(logits, batch['label'])

In [32]:
def train_loop(state, train_ds, batch_size, epoch, learning_rate_fn, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch, learning_rate_fn)
    batch_metrics.append(metrics)
  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {k: sum([metrics[k] for metrics in
    training_batch_metrics])/steps_per_epoch
    for k in training_batch_metrics[0]}
  print('EPOCH: %d\nTraining loss: %.4f, accuracy: %.2f' % (epoch,
    training_epoch_metrics['loss'],
    training_epoch_metrics['accuracy'] * 100))
  return state


In [33]:
def eval_loop(state, test_ds, batch_size):
  eval_ds_size = test_ds['image'].shape[0]
  steps_per_epoch = eval_ds_size // batch_size
  batch_metrics = []
  for i in range(steps_per_epoch):
    batch = {k: v[i*batch_size:(i+1)*batch_size, ...] for k, v in test_ds.items()}
    metrics = eval_step(state, batch)
    batch_metrics.append(metrics)
  eval_batch_metrics = jax.device_get(batch_metrics)
  eval_batch_metrics = {
    k: sum([metrics[k] for metrics in eval_batch_metrics])/steps_per_epoch
    for k in eval_batch_metrics[0]}

  print('    Eval loss: %.4f, accuracy: %.2f' % (eval_batch_metrics['loss'], eval_batch_metrics['accuracy'] * 100))

In [35]:
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2,
  create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
  'tmp/managed', orbax_checkpointer, options)

for epoch in range(train_epoch):
  rng, key = jax.random.split(rng)
  state = train_loop(state, train_ds, batch_size, epoch,
    learning_rate_fn, rng)
  eval_loop(state, test_ds, eval_batch_size)
  checkpoint_manager.save(epoch, ckpt,
    save_kwargs={'save_args':save_args})

EPOCH: 0
Training loss: 0.2486, accuracy: 92.36
    Eval loss: 0.1688, accuracy: 95.09


ValueError: Checkpoint path should be absolute. Got tmp/managed/0.orbax-checkpoint-tmp-1710600407521790/default.orbax-checkpoint-tmp-1710600407527959

In [None]:
raw_restored = orbax_checkpointer.restore('tmp/orbax/single_save')
raw_restored