<a href="https://colab.research.google.com/github/yblee110/jax-flax-book/blob/main/ch03_2_1_%EB%B0%B0%EC%B9%98_%EC%A0%95%EA%B7%9C%ED%99%94_%EC%A0%81%EC%9A%A9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1

In [2]:

import jax
import flax
import optax
import datasets

print("JAX Version : {}".format(jax.__version__)) #출력 JAX Version : 0.4.20
print("FLAX Version : {}".format(flax.__version__)) #출력 FLAX Version : 0.7.5
print("OPTAX Version : {}".format(optax.__version__))#출력 OPTAX Version : 0.1.7

JAX Version : 0.4.26
FLAX Version : 0.8.4
OPTAX Version : 0.2.2


In [3]:
import jax.numpy as jnp
from datasets import load_dataset


def get_datasets():
  datasets = load_dataset("mnist")
  datasets = datasets.with_format("jax")
  datasets = {
    "train": {
      "image": datasets["train"]["image"][...,None].astype(jnp.float32)/255,
      "label": datasets["train"]["label"],
    },
    "test": {
      "image": datasets["test"]["image"][...,None].astype(jnp.float32)/255,
      "label": datasets["test"]["label"],
    },
  }
  return datasets['train'], datasets['test']


In [4]:
train_ds, test_ds = get_datasets()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script:   0%|          | 0.00/3.98k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.83k [00:00<?, ?B/s]

The repository for mnist contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/mnist.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/9.91M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/28.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.54k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [5]:
import flax.linen as nn

class CNN(nn.Module):
  num_classes: int


  @nn.compact
  def __call__(self, x, train: bool):
      x = nn.Conv(features=16, kernel_size=(5, 5), strides=(2, 2),
                  padding='VALID')(x)
      x = nn.BatchNorm(use_running_average=not train)(x)
      x = nn.relu(x)
      x = nn.Conv(features=32, kernel_size=(5, 5), strides=(2, 2),
                  padding='VALID')(x)
      x = nn.BatchNorm(use_running_average=not train)(x)
      x = nn.relu(x)
      x = jnp.mean(x, axis=(1, 2))
      x = nn.Dense(features=self.num_classes)(x)
      return x


In [6]:
rng = jax.random.PRNGKey(0)
model = CNN(num_classes=10)


rng, key = jax.random.split(rng)
variables = model.init(key, jnp.ones((1, 28, 28, 1)), train=False)
params = variables['params']
batch_stats = variables['batch_stats']

In [7]:
jax.tree_util.tree_map(jnp.shape, variables)

{'batch_stats': {'BatchNorm_0': {'mean': (16,), 'var': (16,)},
  'BatchNorm_1': {'mean': (32,), 'var': (32,)}},
 'params': {'BatchNorm_0': {'bias': (16,), 'scale': (16,)},
  'BatchNorm_1': {'bias': (32,), 'scale': (32,)},
  'Conv_0': {'bias': (16,), 'kernel': (5, 5, 1, 16)},
  'Conv_1': {'bias': (32,), 'kernel': (5, 5, 16, 32)},
  'Dense_0': {'bias': (10,), 'kernel': (32, 10)}}}

In [8]:
def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
    'loss': loss,
    'accuracy': accuracy
  }
  return metrics


In [9]:
from flax.training import train_state
from typing import Any

class TrainState(train_state.TrainState):
    batch_stats: Any


learning_rate = 0.001
tx = optax.adam(learning_rate=learning_rate)


state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    batch_stats=batch_stats,
    tx=tx,
)


In [10]:
@jax.jit
def train_step(state, batch):
  def loss_fn(params):
      logits, updates = state.apply_fn(
           {'params': params, 'batch_stats': state.batch_stats},
           batch['image'], train=True, mutable=['batch_stats'])
      loss = jnp.mean(optax.softmax_cross_entropy(
          logits=logits,
          labels=jax.nn.one_hot(batch['label'], num_classes=10))
      )
      return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=updates['batch_stats'])


  metrics = compute_metrics(logits, batch['label'])
  return state, metrics


In [11]:
@jax.jit
def eval_step(state, batch):
  logits = state.apply_fn(
     {'params': state.params, 'batch_stats': state.batch_stats},
     batch['image'], train=False)
  return compute_metrics(logits, batch['label'])

In [12]:
def train_loop(state, train_ds, batch_size, epoch, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # 완성되지 않은 배치는 생략합니다.
  perms = perms.reshape((steps_per_epoch, batch_size))

  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)


  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
    k: sum([metrics[k] for metrics in training_batch_metrics])/steps_per_epoch
    for k in training_batch_metrics[0]}

  print('EPOCH: %d\nTraining loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))
  return state


In [13]:
def eval_loop(state, test_ds, batch_size):
  eval_ds_size = test_ds['image'].shape[0]
  steps_per_epoch = eval_ds_size // batch_size
  batch_metrics = []
  for i in range(steps_per_epoch):
    batch = {k: v[i*batch_size:(i+1)*batch_size, ...] for k, v in test_ds.items()}
    metrics = eval_step(state, batch)
    batch_metrics.append(metrics)
  eval_batch_metrics = jax.device_get(batch_metrics)
  eval_batch_metrics = {
    k: sum([metrics[k] for metrics in eval_batch_metrics])/steps_per_epoch
    for k in eval_batch_metrics[0]}

  print('    Eval loss: %.4f, accuracy: %.2f' % (eval_batch_metrics['loss'], eval_batch_metrics['accuracy'] * 100))

In [14]:
train_epoch = 10
batch_size = 64
eval_batch_size = 100
for epoch in range(train_epoch):
  rng, key = jax.random.split(rng)
  state = train_loop(state, train_ds, batch_size, epoch, rng)
  eval_loop(state, test_ds, eval_batch_size)


EPOCH: 0
Training loss: 0.7520, accuracy: 84.84
    Eval loss: 0.2895, accuracy: 93.25
EPOCH: 1
Training loss: 0.2063, accuracy: 95.42
    Eval loss: 0.1725, accuracy: 95.88
EPOCH: 2
Training loss: 0.1349, accuracy: 96.67
    Eval loss: 0.1026, accuracy: 97.44
EPOCH: 3
Training loss: 0.1042, accuracy: 97.28
    Eval loss: 0.0963, accuracy: 97.46
EPOCH: 4
Training loss: 0.0883, accuracy: 97.62
    Eval loss: 0.0817, accuracy: 97.79
EPOCH: 5
Training loss: 0.0754, accuracy: 97.95
    Eval loss: 0.0773, accuracy: 97.88
EPOCH: 6
Training loss: 0.0679, accuracy: 98.14
    Eval loss: 0.0711, accuracy: 97.91
EPOCH: 7
Training loss: 0.0611, accuracy: 98.26
    Eval loss: 0.0806, accuracy: 97.54
EPOCH: 8
Training loss: 0.0560, accuracy: 98.43
    Eval loss: 0.1012, accuracy: 96.81
EPOCH: 9
Training loss: 0.0518, accuracy: 98.52
    Eval loss: 0.0799, accuracy: 97.40
