In [2]:
import sys, jax, numpy as np
print("Python:", sys.version)
print("JAX:", jax.__version__, "Devices:", jax.devices())
import platform, os
print("OS:", platform.platform())
print("CWD:", os.getcwd())

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Python: 3.10.19 (main, Oct 10 2025, 08:52:10) [GCC 13.3.0]
JAX: 0.4.18 Devices: [CpuDevice(id=0)]
OS: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
CWD: /home/huy/projects/vision_transformer


In [None]:
import sys, jax, flax, optax, numpy as np, scipy
print(sys.executable)                      
print("JAX", jax.__version__)              # 0.4.18
print("Flax", flax.__version__)            # 0.7.2
print("Optax", optax.__version__)          # 0.1.9
print("NumPy", np.__version__)             # 1.26.4
print("SciPy", scipy.__version__)          # 1.10.1


/home/huy/venvs/vit310/bin/python
JAX 0.4.18
Flax 0.7.2
Optax 0.1.9
NumPy 1.26.4
SciPy 1.10.1


In [None]:
!mkdir -p ~/projects/vit_weights
!cd ~/projects/vit_weights
!curl -L -O https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
!curl -L -O https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_32.npz
!curl -L -O https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz


# fine-tuned ImageNet-1k
# curl -L -O https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz


  pid, fd = os.forkpty()


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  393M  100  393M    0     0  9747k      0  0:00:41  0:00:41 --:--:-- 10.7M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  400M  100  400M    0     0  10.3M      0  0:00:38  0:00:38 --:--:-- 11.7M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1246M  100 1246M    0     0  10.9M      0  0:01:53  0:01:53 --:--:-- 11.7M


In [8]:
# Choose a model name exactly from this list:
# "ViT-B_16", "ViT-B_32", "ViT-L_16", "Mixer-B_16"
model_name = "ViT-B_16"

# Map model_name -> direct download URL (no gsutil required)
URLS = {
    "ViT-B_16": "https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz",
    "ViT-B_32": "https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_32.npz",
    "ViT-L_16": "https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz",
    "Mixer-B_16": "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz",
    # (Optionally: ImageNet-1k fine-tuned head)
    # "ViT-B_16_ft": "https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz",
}

import os, urllib.request, pathlib
weights_dir = pathlib.Path.home() / "projects" / "vit_weights"
weights_dir.mkdir(parents=True, exist_ok=True)
dst = weights_dir / f"{model_name}.npz"

if not dst.exists():
    print(f"Downloading {model_name} …")
    urllib.request.urlretrieve(URLS[model_name], dst.as_posix())
    print("Saved to:", dst)
else:
    print("Already exists:", dst)

# Sanity check for the next cells:
assert dst.exists(), "Weights file was not downloaded"
print("OK ->", dst)

Downloading ViT-B_16 …
Saved to: /home/huy/projects/vit_weights/ViT-B_16.npz
OK -> /home/huy/projects/vit_weights/ViT-B_16.npz


In [5]:
!ls -lh /home/huy/projects/vit_weights


total 394M
-rw-r--r-- 1 huy huy 394M Oct 14 11:36 ViT-B_16.npz


In [6]:
from absl import logging
import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm

logging.set_verbosity(logging.INFO)

# Shows the number of available devices.
# In a CPU/GPU runtime this will be a single device.
# In a TPU runtime this will be 8 cores.
jax.local_devices()

[CpuDevice(id=0)]

In [25]:
import os, sys
repo = os.path.expanduser('~/projects/vision_transformer')  # path to the vit_jax repo
assert os.path.isdir(repo), f"Repo not found: {repo}"
if repo not in sys.path:
    sys.path.append(repo)

from vit_jax import models_vit, checkpoint
import vit_jax, inspect
print("vit_jax loaded from:", vit_jax.__file__)



vit_jax loaded from: /home/huy/projects/vision_transformer/vit_jax/__init__.py


In [27]:
from ml_collections import ConfigDict

def vit_config(kind='B', patch=16, num_classes=1000, image_size=224):
    cfg = ConfigDict()

    # patches
    cfg.patches = ConfigDict()
    cfg.patches.size = (patch, patch)

    # transformer sub-config
    tr = ConfigDict()
    if kind == 'B':
        cfg.hidden_size = 768
        tr.num_layers = 12
        tr.mlp_dim = 3072
        tr.num_heads = 12
    elif kind == 'L':
        cfg.hidden_size = 1024
        tr.num_layers = 24
        tr.mlp_dim = 4096
        tr.num_heads = 16
    else:
        raise ValueError("kind must be 'B' or 'L'")

    tr.dropout_rate = 0.0
    tr.attention_dropout_rate = 0.0
    cfg.transformer = tr

    cfg.classifier = 'token'
    cfg.representation_size = None
    cfg.num_classes = num_classes
    cfg.image_size = image_size
    return cfg

CFG = vit_config('B', 16, num_classes=1000, image_size=224)
CFG



classifier: token
hidden_size: 768
image_size: 224
num_classes: 1000
patches:
  size: !!python/tuple
  - 16
  - 16
representation_size: null
transformer:
  attention_dropout_rate: 0.0
  dropout_rate: 0.0
  mlp_dim: 3072
  num_heads: 12
  num_layers: 12

In [None]:
import os
from vit_jax import checkpoint
repo = "/home/huy/projects/vision_transformer"
ckpt = os.path.join(repo, "ViT-B_16.npz")   # path to the downloaded checkpoint
params = checkpoint.load(ckpt)


In [None]:
# ===== 1) Nhận diện số lớp của checkpoint =====
import os
from vit_jax import checkpoint

repo = "/home/huy/projects/vision_transformer"
ckpt = os.path.join(repo, "ViT-B_16.npz")   # path to the downloaded checkpoint

params_raw = checkpoint.load(ckpt)

# Some checkpoints have a nested 'params' key.
params = params_raw.get('params', params_raw)

def ckpt_num_classes(p):
    # head/kernel shape: [hidden_size, num_classes]
    try:
        return int(p['head']['kernel'].shape[-1])
    except Exception:
        return None

n_cls = ckpt_num_classes(params)
print("Checkpoint classes =", n_cls)  # 21843 (ImageNet-21k) hoặc 1000 (ImageNet-1k fine-tuned)


Checkpoint classes = 21843


In [None]:
# config model correspond to checkpoint
from ml_collections import ConfigDict

def vit_config(kind='B', patch=16, num_classes=1000, image_size=224):
    cfg = ConfigDict()
    cfg.patches = ConfigDict(); cfg.patches.size = (patch, patch)
    tr = ConfigDict()
    if kind == 'B':
        cfg.hidden_size = 768; tr.num_layers = 12; tr.mlp_dim = 3072; tr.num_heads = 12
    elif kind == 'L':
        cfg.hidden_size = 1024; tr.num_layers = 24; tr.mlp_dim = 4096; tr.num_heads = 16
    else:
        raise ValueError("kind must be 'B' or 'L'")
    tr.dropout_rate = 0.0; tr.attention_dropout_rate = 0.0
    cfg.transformer = tr
    cfg.classifier = 'token'
    cfg.representation_size = None
    cfg.num_classes = num_classes
    cfg.image_size = 224
    return cfg

# Nếu ckpt là 21k, đặt num_classes=21843; nếu là 1k, đặt 1000
CFG = vit_config('B', 16, num_classes=(n_cls or 1000), image_size=224)
CFG


classifier: token
hidden_size: 768
image_size: 224
num_classes: 21843
patches:
  size: !!python/tuple
  - 16
  - 16
representation_size: null
transformer:
  attention_dropout_rate: 0.0
  dropout_rate: 0.0
  mlp_dim: 3072
  num_heads: 12
  num_layers: 12

In [None]:

import jax, jax.numpy as jnp
from vit_jax import models_vit

model = models_vit.VisionTransformer(
    num_classes=CFG.num_classes,
    patches=CFG.patches,
    transformer=CFG.transformer,
    hidden_size=CFG.hidden_size,
    representation_size=CFG.representation_size,
    classifier=CFG.classifier,
)

rng = jax.random.PRNGKey(0)
dummy = jnp.ones([1, CFG.image_size, CFG.image_size, 3], jnp.float32)
variables = model.init(rng, dummy, train=False)

# thay params init bằng params trong checkpoint
variables = {**variables, 'params': params}

logits = model.apply(variables, dummy, train=False)
print("logits shape =", logits.shape)   # kỳ vọng: (1, num_classes)


logits shape = (1, 21843)


In [35]:
!curl -L -o ~/projects/vision_transformer/ViT-B_16-im21k+1k.npz \
  https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz






  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  331M  100  331M    0     0  10.2M      0  0:00:32  0:00:32 --:--:-- 11.6M


In [38]:
import jax, jax.numpy as jnp
from vit_jax import models_vit, checkpoint
from ml_collections import ConfigDict

# Config ViT-B/16 @224
def vit_config(num_classes=10, image_size=224):
    cfg = ConfigDict(); cfg.patches = ConfigDict(); cfg.patches.size = (16, 16)
    tr = ConfigDict()
    cfg.hidden_size = 768; tr.num_layers = 12; tr.mlp_dim = 3072; tr.num_heads = 12
    tr.dropout_rate = 0.0; tr.attention_dropout_rate = 0.0
    cfg.transformer = tr; cfg.classifier='token'; cfg.representation_size=None
    cfg.num_classes=num_classes; cfg.image_size=image_size
    return cfg

CFG = vit_config(num_classes=10, image_size=224)  # CIFAR-10 có 10 lớp

model = models_vit.VisionTransformer(
    num_classes=CFG.num_classes,
    patches=CFG.patches,
    transformer=CFG.transformer,
    hidden_size=CFG.hidden_size,
    representation_size=CFG.representation_size,
    classifier=CFG.classifier,
)

# Init skeleton variables
rng = jax.random.PRNGKey(0)
dummy = jnp.ones([1, 224, 224, 3], jnp.float32)
variables = model.init(rng, dummy, train=False)

# Load backbone pretrain (21k hoặc 1k đều được vì ta sẽ thay head)
ckpt = "/home/huy/projects/vision_transformer/ViT-B_16.npz"  # đường dẫn ckpt bạn có (21k hay 1k)
params = checkpoint.load(ckpt).get('params', None) or checkpoint.load(ckpt)

# Thay head -> 10 lớp, giữ backbone
params = params.copy()
params['head'] = {
    'kernel': jnp.zeros([CFG.hidden_size, 10], jnp.float32),
    'bias':   jnp.zeros([10], jnp.float32),
}
variables = {**variables, 'params': params}


In [39]:
import tensorflow as tf, tensorflow_datasets as tfds
tf.config.set_visible_devices([], 'GPU')  # tránh TF chiếm GPU nếu bạn dùng GPU cho JAX

IM_MEAN = tf.constant([0.485, 0.456, 0.406], tf.float32)
IM_STD  = tf.constant([0.229, 0.224, 0.225], tf.float32)

def preprocess_tf(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    # Resize -> 224 (CIFAR-10 là 32 nên không cần giữ tỷ lệ)
    image = tf.image.resize(image, [224, 224], method='bicubic')
    image = (image - IM_MEAN) / IM_STD
    return image, tf.cast(label, tf.int32)

BATCH = 128

train_ds = tfds.load('cifar10', split='train[:90%]', as_supervised=True)
val_ds   = tfds.load('cifar10', split='train[90%:]', as_supervised=True)  # dùng 10% train làm val
test_ds  = tfds.load('cifar10', split='test', as_supervised=True)

train_ds = (train_ds.shuffle(10_000).map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
                         .batch(BATCH).prefetch(tf.data.AUTOTUNE))
val_ds   = (val_ds.map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
                   .batch(BATCH).prefetch(tf.data.AUTOTUNE))
test_ds  = (test_ds.map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
                    .batch(BATCH).prefetch(tf.data.AUTOTUNE))


INFO:absl:Load dataset info from /home/huy/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Reusing dataset cifar10 (/home/huy/tensorflow_datasets/cifar10/3.0.2)
INFO:absl:Constructing tf.data.Dataset cifar10 for split train[:90%], from /home/huy/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Load dataset info from /home/huy/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Reusing dataset cifar10 (/home/huy/tensorflow_datasets/cifar10/3.0.2)
INFO:absl:Constructing tf.data.Dataset cifar10 for split train[90%:], from /home/huy/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Load dataset info from /home/huy/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Reusing dataset cifar10 (/home/huy/tensorflow_datasets/cifar10/3.0.2)
INFO:absl:Constructing tf.data.Dataset cifar10 for split test, from /home/huy/tensorflow_datasets/cifar10/3.0.2


In [42]:
import tensorflow_datasets as tfds

PROJECT_TFDS = "/home/huy/projects/vision_transformer/data/tfds"  # any path 

train_ds = tfds.load('cifar10', split='train[:90%]',  as_supervised=True,
                     data_dir=PROJECT_TFDS, download=True)
val_ds   = tfds.load('cifar10', split='train[90%:]',   as_supervised=True,
                     data_dir=PROJECT_TFDS, download=True)
test_ds  = tfds.load('cifar10', split='test',          as_supervised=True,
                     data_dir=PROJECT_TFDS, download=True)
    # shows how splits can be sliced


INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar10/3.0.2
INFO:absl:Load dataset info from /tmp/tmpew0d7g7wtfds
INFO:absl:Fields info.[citation, splits, supervised_keys, module_name] from disk and from code do not match. Keeping the one from code.
INFO:absl:Generating dataset cifar10 (/home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2)


[1mDownloading and preparing dataset 162.17 MiB (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2...[0m


Dl Completed...: 0 url [00:00, ? url/s]
[AINFO:absl:Downloading https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz into /home/huy/projects/vision_transformer/data/tfds/downloads/cs.toronto.edu_kriz_cifar-10-binaryODHPtIjLh3oLcXirEISTO7dkzyKjRCuol6lV8Wc6C7s.tar.gz.tmp.421f9a8e3e544a06b3544645692956b9...
Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Extraction completed...: 100%|██████████| 8/8 [00:16<00:00,  2.12s/ file]
Dl Size...: 100%|██████████| 162/162 [00:16<00:00,  9.54 MiB/s]
Dl Completed...: 100%|██████████| 1/1 [00:16<00:00, 16.99s/ url]
Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]INFO:absl:Done writing /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2.incompleteJD7TRB/cifar10-train.tfrecord*. Number of examples: 50000 (shards: [50000])
Generating splits...:  50%|█████     | 1/2 [00:07<00:07,  7.69s/ splits]INFO:absl:Done writing /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2.incompleteJD7TRB/cifar10-test.tfrecord

[1mDataset cifar10 downloaded and prepared to /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2. Subsequent calls will reuse this data.[0m


In [None]:
import tensorflow_datasets as tfds, jax.numpy as jnp

for x, y in tfds.as_numpy(train_ds.take(1)):
    print("train batch shape:", x.shape, x.dtype)
    #Wrong (128, 224, 224, 3) float32


train batch shape: (32, 32, 3) uint8


2025-10-14 16:18:22.907502: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [45]:
import tensorflow as tf

IM_MEAN = tf.constant([0.485, 0.456, 0.406], tf.float32)
IM_STD  = tf.constant([0.229, 0.224, 0.225], tf.float32)
IMG_SIZE = CFG.image_size  # 224 (hoặc 384 nếu dùng ckpt 384)

def preprocess_tf(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE], method='bicubic')

    # đảm bảo có 3 kênh
    if image.shape.rank == 2:
        image = tf.expand_dims(image, -1)
    if image.shape[-1] == 1:
        image = tf.repeat(image, 3, axis=-1)

    image = (image - IM_MEAN) / IM_STD
    return image, tf.cast(label, tf.int32)


In [51]:
import tensorflow_datasets as tfds

BATCH = 32
PROJECT_TFDS = "/home/huy/projects/vision_transformer/data/tfds"

train_ds = (tfds.load('cifar10', split='train[:90%]', as_supervised=True, data_dir=PROJECT_TFDS)
              .shuffle(10_000)
              .map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
              .batch(BATCH)
              .prefetch(tf.data.AUTOTUNE))

val_ds = (tfds.load('cifar10', split='train[90%:]', as_supervised=True, data_dir=PROJECT_TFDS)
            .map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH)
            .prefetch(tf.data.AUTOTUNE))

test_ds = (tfds.load('cifar10', split='test', as_supervised=True, data_dir=PROJECT_TFDS)
             .map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
             .batch(BATCH)
             .prefetch(tf.data.AUTOTUNE))


INFO:absl:Load dataset info from /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2
INFO:absl:Reusing dataset cifar10 (/home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2)
INFO:absl:Constructing tf.data.Dataset cifar10 for split train[:90%], from /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2
INFO:absl:Load dataset info from /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2
INFO:absl:Reusing dataset cifar10 (/home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2)
INFO:absl:Constructing tf.data.Dataset cifar10 for split train[90%:], from /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2
INFO:absl:Load dataset info from /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2
INFO:absl:Reusing dataset cifar10 (/home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2)
INFO:absl:Constructing tf.data.Dataset cifar10 for split test, from /home/huy/projects/vision_transformer/data/tfds/cifar10/3.0.2


In [52]:
import tensorflow_datasets as tfds, jax.numpy as jnp
for x, y in tfds.as_numpy(train_ds.take(1)):
    print("batched train shape:", x.shape, x.dtype)  # kỳ vọng: (B, 224, 224, 3) float32


batched train shape: (32, 224, 224, 3) float32


2025-10-14 16:23:55.899510: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [53]:
params = params.copy()
params['head'] = {
    'kernel': jnp.zeros([CFG.hidden_size, 10], jnp.float32),
    'bias': jnp.zeros([10], jnp.float32),
}


In [54]:
import jax, jax.numpy as jnp
from jax import tree_util as jtu    # <--- dùng tree_util
import optax, tensorflow_datasets as tfds

# ----- tách params -----
full_params = params                           # params đã build + load ckpt
backbone_params = {k: v for k, v in full_params.items() if k != 'head'}

# stop gradient cho toàn bộ backbone (trên toàn pytree)
backbone_params = jtu.tree_map(jax.lax.stop_gradient, backbone_params)

head_params = full_params['head']              

# optimizer chỉ cho head
tx = optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True)
opt_state = tx.init(head_params)

# loss chỉ nhận head; backbone là hằng số (đã stop_gradient)
def loss_fn_head(head, x, y):
    merged = {'head': head, **backbone_params}
    logits = model.apply({'params': merged}, x, train=True)
    loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()
    acc  = (logits.argmax(-1) == y).mean()
    return loss, acc

grad_fn = jax.value_and_grad(loss_fn_head, has_aux=True)

@jax.jit
def train_step(head, opt_state, x, y):
    (loss, acc), grads = grad_fn(head, x, y)
    updates, opt_state = tx.update(grads, opt_state, head)
    head = optax.apply_updates(head, updates)
    return head, opt_state, loss, acc

def evaluate(p_head, ds):
    merged = {'head': p_head, **backbone_params}
    tot = cor = 0
    for x, y in tfds.as_numpy(ds):
        x, y = jnp.asarray(x), jnp.asarray(y)
        logits = model.apply({'params': merged}, x, train=False)
        cor += (logits.argmax(-1) == y).sum()
        tot += y.shape[0]
    return float(cor) / tot

# ===== train loop (giảm batch nếu cần, ví dụ BATCH=32/16) =====
for epoch in range(10):
    n, loss_sum, acc_sum = 0, 0.0, 0.0
    for x, y in tfds.as_numpy(train_ds):
        x, y = jnp.asarray(x), jnp.asarray(y)
        head_params, opt_state, loss, acc = train_step(head_params, opt_state, x, y)
        loss_sum += float(loss); acc_sum += float(acc); n += 1
    val_acc = evaluate(head_params, val_ds)
    print(f"Epoch {epoch+1:02d}: loss={loss_sum/n:.4f}  train_acc={acc_sum/n:.4f}  val_acc={val_acc:.4f}")

test_acc = evaluate(head_params, test_ds)
print("Test top-1 accuracy:", test_acc)

# nếu muốn lưu lại full params:
params = {'head': head_params, **backbone_params}


Epoch 01: loss=0.2673  train_acc=0.9170  val_acc=0.9230
Epoch 02: loss=0.1976  train_acc=0.9345  val_acc=0.9250
Epoch 03: loss=0.1800  train_acc=0.9399  val_acc=0.9266
Epoch 04: loss=0.1688  train_acc=0.9428  val_acc=0.9242
Epoch 05: loss=0.1616  train_acc=0.9452  val_acc=0.9228
Epoch 06: loss=0.1559  train_acc=0.9468  val_acc=0.9264
Epoch 07: loss=0.1515  train_acc=0.9483  val_acc=0.9240
Epoch 08: loss=0.1473  train_acc=0.9497  val_acc=0.9244
Epoch 09: loss=0.1440  train_acc=0.9504  val_acc=0.9246
Epoch 10: loss=0.1421  train_acc=0.9517  val_acc=0.9224
Test top-1 accuracy: 0.9256


In [61]:
params = {'head': head_params, **backbone_params}


In [63]:
import jax, flax, pkgutil
print("jax:", jax.__version__)
import orbax
print("orbax-checkpoint:", getattr(orbax, "__version__", "unknown"))


jax: 0.4.18
orbax-checkpoint: unknown


In [65]:
import os, json, time
from flax import serialization

ts = time.strftime("%Y-%m-%d_%H-%M-%S")
ckpt_dir = os.path.join("ckpts_vit", ts)
os.makedirs(ckpt_dir, exist_ok=True)

# 1) gộp full params (khuyên dùng)
full_params = {'head': head_params, **backbone_params}

# 2) save params & opt_state (msgpack)
with open(os.path.join(ckpt_dir, "params.msgpack"), "wb") as f:
    f.write(serialization.to_bytes(full_params))

with open(os.path.join(ckpt_dir, "opt_state.msgpack"), "wb") as f:
    f.write(serialization.to_bytes(opt_state))

# 3) meta (bạn có thể thêm label_map, num_classes, patch_size…)
meta = {
    "epoch": 10,
    "test_acc": float(test_acc),
    "optimizer": {"type": "sgd", "lr": 0.1, "momentum": 0.9, "nesterov": True},
    "notes": "ViT head-only finetune; backbone frozen",
}
with open(os.path.join(ckpt_dir, "meta.json"), "w", encoding="utf-8") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)

print("Saved checkpoint to:", ckpt_dir)



Saved checkpoint to: ckpts_vit/2025-10-15_08-46-27


In [67]:
print(os.path.abspath(ckpt_dir))

/home/huy/projects/vision_transformer/ckpts_vit/2025-10-15_08-46-27


In [None]:
val_acc = evaluate(head_params, val_ds)   
print("Val acc after load:", val_acc)