<a href="https://colab.research.google.com/github/yblee110/jax-flax-book/blob/main/ch04_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jax==0.4.24
!pip install flax==0.7.5
!pip install optax==0.1.7
!pip install datasets

Collecting jax==0.4.24
  Downloading jax-0.4.24-py3-none-any.whl (1.8 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/1.8 MB[0m [31m8.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m29.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.4.23
    Uninstalling jax-0.4.23:
      Successfully uninstalled jax-0.4.23
Successfully installed jax-0.4.24
Collecting flax==0.7.5
  Downloading flax-0.7.5-py3-none-any.whl (244 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.4/244.4 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: flax
  Attempting uninstall: flax
    Found existing installation: flax 0.8.1
    Uninstalling flax-0.8.1:
      Succes

In [2]:
from functools import partial
import albumentations as A
import datasets
from flax.jax_utils import prefetch_to_device
import jax
import jax.numpy as jnp
import numpy as np




# 학습에 사용할 디바이스가 여러 개인 경우 샤딩을 사용합니다.
def shard(x, devices):


   B, *D = x.shape
   num_devices = len(devices)
   return jnp.reshape(x, [num_devices, B // num_devices, *D])






class build_dataset_providers:
   def __init__(self, rng, dtype=jnp.float32, test_only=False, batch_size = 256, val_batch_size = 250):
       self.dtype = dtype


       data = datasets.load_dataset('cifar10').with_format("numpy")
       data = {
           "train": {
               "img": data["train"]["img"],
               "label": data["train"]["label"],
           },
           "test": {
               "img": data["test"]["img"],
               "label": data["test"]["label"],
           },
       }


       self.num_classes = 10
       self.input_size = [32, 32, 3]


       self.devices = jax.local_devices()
       # cardinality: 중복 수치를 의미합니다.
       self.cardinality = {}
       # iter_len:반복하는 횟수를 의미합니다.
       self.iter_len = {}
       self.rng = {}


       gen_provider = self.gen_jax_w_alb_provider


       self.provider = {}


       if not (test_only):
           self.rng["train"] = rng
           self.provider["train"] = gen_provider(
               "train",
               data["train"],
               batch_size,
               shuffle=True,
               drop_remainder=True,
           )
       self.rng["test"] = rng
       self.provider["test"] = gen_provider(
           "test", data["test"], val_batch_size, drop_remainder=False
       )


       print("=" * 50)
       print("data providers are built as follows:")
       print("- Data cardinality     :", self.cardinality)
       print("- Number of iterations :", self.iter_len)
       print("=" * 50, "\n")


   def gen_jax_w_alb_provider(
       self,
       split,
       data,
       batch_size,
       shuffle=False,
       drop_remainder=True,
   ):


       self.cardinality[split] = len(data["img"])
       self.iter_len[split] = (
           self.cardinality[split] // batch_size
           if drop_remainder
           else int(np.ceil(self.cardinality[split] / batch_size))
       )


       transform = A.Compose(
           [
               A.HorizontalFlip(p=0.5),
               A.PadIfNeeded(min_height=40, min_width=40, p=1),
               A.RandomCrop(height=32, width=32, p=1),
           ]
       )


       def provider():
           indices = np.arange(self.cardinality[split])




           if shuffle:
               self.rng[split], key = jax.random.split(self.rng[split])
               indices = jax.random.shuffle(key, indices)


           for batch in range(self.iter_len[split]):
               curr_idx = indices[batch * batch_size : (batch + 1) * batch_size]
               batch_data = {k: d[curr_idx] for k, d in data.items()}


               if split == "train":
                   batch_data["img"] = np.stack(
                       [transform(image=image)["image"] for image in batch_data["img"]]
                   )


               batch_data["img"] = jax.device_put(
                   batch_data["img"], jax.devices("cpu")[0]
               )
               batch_data["img"] = batch_data["img"].astype(self.dtype) / 255
               yield {k: shard(d, self.devices) for k, d in batch_data.items()}
           print(len(batch_data['img']))


       return lambda: prefetch_to_device(provider(), size=10)


rng = jax.random.PRNGKey(2048)
rng, key = jax.random.split(rng)


datasets = build_dataset_providers(key)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/120M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

data providers are built as follows:
- Data cardinality     : {'train': 50000, 'test': 10000}
- Number of iterations : {'train': 195, 'test': 40}



In [3]:
from typing import Any


from flax import linen as nn
import jax
import jax.numpy as jnp
from transformers import FlaxCLIPModel


ModuleDef = Any


def preprocess_for_CLIP(image):
   """
    CLIP을 위한 전처리
   """
   image = image.transpose(0, 3, 1, 2)
   B, D = image.shape[:2]
   mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
   std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
   image = jax.image.resize(
       image, (B, D, 224, 224), "bicubic"
   )  # # 이미지가 직사각형 모양shape이라고 가정합니다.
   image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
   return image


def model(num_classes=10, dtype=jnp.float32):
   _model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32", dtype=dtype)


   class CLIP(nn.Module):
       num_classes: int
       dtype: Any = jnp.float32


       @nn.compact
       def __call__(self, x, train=False):
           dense = nn.Dense


           x = preprocess_for_CLIP(x)
           emb = _model.get_image_features(x)
           return dense(
               features=self.num_classes, dtype=self.dtype, name="classifier"
           )(emb)


   return CLIP(num_classes, dtype)


model = model(num_classes=10, dtype=jnp.float32)


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

flax_model.msgpack:   0%|          | 0.00/605M [00:00<?, ?B/s]

In [4]:
from typing import Any, Callable


from flax import struct
from flax.core import FrozenDict
from flax.training import common_utils, train_state
from flax.jax_utils import replicate
import jax
import jax.numpy as jnp
import optax




def initialized(key, input_size, model):
   """
   PRNG를 사용해 주어진 모델 파라미터 초기화하기


   Args:
       rng: 랜덤 키를 사용하기 위한 RPNG 키
       input_size: 입력 데이터 사이즈


   """


   input_size = (1, *input_size)


   def init_model():
       return model.init(key, jnp.ones(input_size, model.dtype), train=False)


   variables = jax.jit(init_model, backend="cpu")()
   return variables["params"], variables.get("batch_stats", FrozenDict({}))




class TrainState(train_state.TrainState):
   batch_stats: Any






def l2_weight_decay(params, grads, weight_decay):


   params_flat, treedef = jax.tree_util.tree_flatten(params)
   grads_flat = treedef.flatten_up_to(grads)
   grads_flat = [
       grad + param * weight_decay for param, grad in zip(params_flat, grads_flat)
   ]
   new_grads = jax.tree_util.tree_unflatten(treedef, grads_flat)
   return new_grads




def create_train_state(
   rng, model, input_size, learning_rate_fn, params=None, batch_stats=None
):


   if params is None:
       params, batch_stats = initialized(rng, input_size, model)
   else:
       features_dict = {}


       def features_check(keys, variables):
           v = variables[list(variables.keys())[0]]
           features_dict[keys[-1]] = v.shape[-1]


       def rebuild_tree(frozen_dict, K):
           for k, layer in frozen_dict.items():
               if "mask" not in k:
                   if (
                       any(
                           [
                               not (isinstance(p, FrozenDict) or isinstance(p, dict))
                               for _, p in layer.items()
                           ]
                       )
                       and layer
                   ):
                       features_check(K + [k], layer)
                   else:
                       rebuild_tree(layer, K + [k])


       rebuild_tree(params, [])
       model.features_dict = features_dict


   tx = optax.sgd(
       learning_rate=learning_rate_fn,
       momentum=0.9,
       nesterov=True,
   )
   state = TrainState.create(
       apply_fn=model.apply,
       params=params,
       tx=tx,
       batch_stats=batch_stats,
   )
   return state




def create_train_step(weight_decay):


   @jax.jit
   def train_step(state, batch, dropout_rng):
       def forward(params):
           variables = {"params": params, "batch_stats": state.batch_stats}
           logits, new_state = state.apply_fn(
               variables,
               batch["img"],
               rngs=dict(dropout=dropout_rng),
               mutable=["batch_stats"],
           )


           # objective function
           one_hot_labels = common_utils.onehot(
               batch["label"], num_classes=logits.shape[-1]
           )
           loss = jnp.mean(
               optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
           )
           return loss, (new_state, logits, loss)


       grad_fn = jax.value_and_grad(forward, has_aux=True)
       aux, grads = grad_fn(state.params)
       new_state, logits, loss = aux[1]


       grads = jax.lax.pmean(grads, axis_name="batch")
       grads = l2_weight_decay(state.params, grads, weight_decay)


       accuracy = jnp.mean(jnp.argmax(logits, -1) == batch["label"])
       new_state = state.apply_gradients(
           grads=grads, batch_stats=new_state["batch_stats"]
       )


       metrics = {
           "loss": loss,
           "accuracy": accuracy * 100,
       }
       metrics = jax.lax.pmean(metrics, axis_name="batch")


       return new_state, metrics, dropout_rng


   train_step = jax.pmap(train_step, axis_name="batch")


   cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "batch"), "batch")


   def sync_batch_stats(state):
       return state.replace(batch_stats=cross_replica_mean(state.batch_stats))


   return train_step, sync_batch_stats




def create_eval_step(num_classes):


   @jax.jit
   def eval_step(state, batch):
       variables = {"params": state.params, "batch_stats": state.batch_stats}
       logits = state.apply_fn(variables, batch["img"], train=False)


       # objective function
       one_hot_labels = common_utils.onehot(batch["label"], num_classes=num_classes)
       loss = jnp.mean(
           optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
       )


       accuracy = jnp.mean(jnp.argmax(logits, -1) == batch["label"])
       metrics = {
           "loss": loss,
           "accuracy": accuracy * 100,
       }
       metrics = jax.lax.pmean(metrics, axis_name="batch")
       return metrics


   eval_step = jax.pmap(eval_step, axis_name="batch")
   return eval_step


In [5]:
decay_points = [0.3, 0.6, 0.8]
train_epoch = 20
decay_rate = 0.2
learning_rate = 1e-1


learning_rate_fn = optax.piecewise_constant_schedule(learning_rate,
       {
           int(dp * train_epoch * datasets.iter_len["train"]): decay_rate
           for dp in decay_points
       },
   )


rng, key = jax.random.split(rng)
state = create_train_state(
       key, model, datasets.input_size, learning_rate_fn
   )
state = replicate(state)


In [6]:
!pip install orbax


from orbax import checkpoint
import os


train_path = os.getenv('HOME') + '/clip/'


ckpt_mgr = checkpoint.CheckpointManager(
       train_path,
       checkpoint.Checkpointer(checkpoint.PyTreeCheckpointHandler()),
       checkpoint.CheckpointManagerOptions(
           create=not (os.path.isdir(train_path)),
           max_to_keep=3,
           step_prefix="model_epoch",
       ),
   )


Collecting orbax
  Downloading orbax-0.1.9.tar.gz (1.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: orbax
  Building wheel for orbax (setup.py) ... [?25l[?25hdone
  Created wheel for orbax: filename=orbax-0.1.9-py3-none-any.whl size=1498 sha256=43fc6b28c811d1ab1b68f8fbc42976f78c1fa105844cd03bd34c18da9770f9b5
  Stored in directory: /root/.cache/pip/wheels/14/7a/98/b955a4db98b54317c311ee32367994ca530721c62a87ec56a7
Successfully built orbax
Installing collected packages: orbax
Successfully installed orbax-0.1.9


In [7]:
from collections import OrderedDict
import json
import os


from flax import jax_utils
import jax.numpy as jnp
import numpy as np


class summary:
   def __init__(
       self,
   ):
       self.holder = {}


   def assign(self, key_value, num_data=1):
       for k, v in key_value.items():
           v = np.array(jax_utils.unreplicate(v)).item()
           if k in self.holder:
               self.holder[k] = [
                   self.holder[k][0] + v * num_data,
                   self.holder[k][1] + num_data,
               ]
           else:
               self.holder[k] = [v * num_data, num_data]


   def reset(self, keys=None):
       if keys is None:
           self.holder = {}
       else:
           for k in keys:
               del self.holder[k]


   def result(self, keys):
       return {k: self.holder[k][0] / self.holder[k][1] for k in keys}


In [8]:
weight_decay = 5e-4


train_step, sync_batch_stats = create_train_step(weight_decay)
eval_step = create_eval_step(10)


logger = summary()
update_rng = jax_utils.replicate(rng)


In [9]:
import time
from flax.training import orbax_utils


do_log = 30


tic = time.time()
for epoch in range(train_epoch):
   # 학습 루프
   for batch in datasets.provider["train"]():
       state, metrics, update_rng = train_step(state, batch, update_rng)
       metrics = {"train/" + k: v for k, v in metrics.items()}


       logger.assign(metrics, num_data=batch["img"].shape[1])
       step = int(state.step.mean().item())
       if step % do_log == 0:
           train_time = time.time() - tic


           local_result = logger.result(metrics.keys())
           print(
               "Global step {0:6d}: loss = {1:0.4f}, \
               acc = {2:0.2f} ({3:1.3f} sec/step)".format(
                   step,
                   local_result["train/loss"],
                   local_result["train/accuracy"],
                   train_time / do_log,
               )
           )


           tic = time.time()


   epoch += 1
   if len(state.batch_stats) > 0:
       state = sync_batch_stats(state)


   state_ = jax_utils.unreplicate(state)


   save_args = orbax_utils.save_args_from_target(state_)
   ckpt_mgr.save(epoch, state_, save_kwargs={"save_args": save_args})


   train_result = logger.result(metrics.keys())


   test_tic = time.time()
   # 평가 루프
   for batch in datasets.provider["test"]():
       metrics = eval_step(state, batch)
       metrics = {"test/" + k: v for k, v in metrics.items()}


       logger.assign(metrics, num_data=batch["img"].shape[1])


   eval_result = logger.result(metrics.keys())
   print("=" * 50)
   print(
       "Epoch {0:3d}:\n\tTest loss = {1:0.4f}, Test acc = {2:0.2f}".format(
           epoch, eval_result["test/loss"], eval_result["test/accuracy"]
       )
   )
   print("=" * 50)


   logger.reset()


   # 소요 시간 계산하기
   tic = tic + time.time() - test_tic


  indices = jax.random.shuffle(key, indices)


Global step     30: loss = 0.8415,                acc = 74.95 (1.619 sec/step)
Global step     60: loss = 0.5913,                acc = 81.97 (0.696 sec/step)
Global step     90: loss = 0.4964,                acc = 84.63 (0.699 sec/step)
Global step    120: loss = 0.4430,                acc = 86.12 (0.671 sec/step)
Global step    150: loss = 0.4104,                acc = 87.04 (0.668 sec/step)
Global step    180: loss = 0.3879,                acc = 87.72 (0.680 sec/step)
256
250
Epoch   1:
	Test loss = 0.2101, Test acc = 93.11


  indices = jax.random.shuffle(key, indices)


Global step    210: loss = 0.2617,                acc = 91.85 (0.693 sec/step)
Global step    240: loss = 0.2678,                acc = 91.35 (0.676 sec/step)
Global step    270: loss = 0.2671,                acc = 91.30 (0.669 sec/step)
Global step    300: loss = 0.2597,                acc = 91.47 (0.679 sec/step)
Global step    330: loss = 0.2599,                acc = 91.46 (0.684 sec/step)
Global step    360: loss = 0.2620,                acc = 91.38 (0.680 sec/step)
256
Global step    390: loss = 0.2596,                acc = 91.42 (0.675 sec/step)
250
Epoch   2:
	Test loss = 0.1962, Test acc = 93.73
Global step    420: loss = 0.2519,                acc = 91.51 (0.680 sec/step)
Global step    450: loss = 0.2506,                acc = 91.66 (0.678 sec/step)
Global step    480: loss = 0.2482,                acc = 91.76 (0.679 sec/step)
Global step    510: loss = 0.2469,                acc = 91.81 (0.681 sec/step)
Global step    540: loss = 0.2459,                acc = 91.79 (0.680 sec/s