# Playing with checkpoints

In [26]:
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np
import os

# Directory (under the current working directory) that holds the checkpoints written below.
# NOTE(review): the helper's name suggests any existing contents are erased first -
# confirm before pointing this at a directory with real data.
ckpt_dir = ocp.test_utils.erase_and_create_empty(os.getcwd() + '/data/checkpoints/')

class SelfAttentionHead(nnx.Module):
    """A single causal self-attention head.

    The input is projected (without bias) to query, key and value vectors of
    width ``head_dim``. Attention scores are scaled by ``sqrt(head_dim)``,
    masked so that position ``t`` only attends to positions ``<= t``, and
    normalised with a softmax over the key axis.

    Args:
        n_embed: Size of the last axis of the input.
        head_dim: Output width of the query/key/value projections.
        rngs: ``nnx.Rngs`` passed to the three ``nnx.Linear`` projections.
    """

    def __init__(self, n_embed, head_dim, rngs: nnx.Rngs):
        self.query = nnx.Linear(n_embed, head_dim, rngs=rngs, use_bias=False)
        self.key = nnx.Linear(n_embed, head_dim, rngs=rngs, use_bias=False)
        self.value = nnx.Linear(n_embed, head_dim, rngs=rngs, use_bias=False)

    def __call__(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Rank-3 array; its shape is unpacked as ``(B, T, C)``.

        Returns:
            Array of shape ``(B, T, head_dim)``.
        """
        _, T, _ = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scale by the actual per-head width instead of a hard-coded constant,
        # so the head is correct for any `head_dim`.
        scale = jnp.sqrt(q.shape[-1])
        attn = jnp.einsum('btd,bTd->btT', q, k) / scale

        # Lower-triangular mask: True where a query position may see a key position.
        causal = jnp.tril(jnp.ones((T, T), dtype=bool))
        attn = jnp.where(causal, attn, float('-inf'))
        # Normalise over the key axis (the last one).
        attn = jax.nn.softmax(attn, axis=-1)

        return attn @ v
    
class MultiHeadAttention(nnx.Module):
    """Several ``SelfAttentionHead`` modules run side by side, then projected.

    Each head sees the same input; their outputs are concatenated along the
    last axis (giving ``n_head * head_dim`` features) and passed through a
    linear projection back to ``n_embed`` features.

    Args:
        n_embed: Size of the last axis of the input and of the output.
        n_head: Number of attention heads.
        head_dim: Output width of each head.
        rngs: ``nnx.Rngs`` passed to every head and to the output projection.
    """

    def __init__(self, n_embed, n_head, head_dim, rngs: nnx.Rngs):
        self.heads = [SelfAttentionHead(n_embed, head_dim, rngs) for _ in range(n_head)]
        # The projection consumes the concatenated head outputs, so its input
        # width is n_head * head_dim. This equals n_embed only when
        # head_dim == n_embed / n_head exactly; sizing it from the heads keeps
        # the module usable when n_embed is not divisible by n_head.
        self.proj = nnx.Linear(n_head * head_dim, n_embed, rngs=rngs)

    def __call__(self, x):
        """Run every head on ``x``, concatenate on the last axis and project.

        Args:
            x: Input array forwarded unchanged to each head.

        Returns:
            Array whose last axis has size ``n_embed``.
        """
        head_outputs = [head(x) for head in self.heads]
        combined = jnp.concatenate(head_outputs, axis=-1)
        return self.proj(combined)
    
class FeedForward(nnx.Module):
    """Two-layer position-wise MLP with a ReLU in between.

    The hidden layer is four times as wide as the input; the output has the
    same width as the input.

    Args:
        n_embed: Size of the last axis of the input and of the output.
        rngs: ``nnx.Rngs`` passed to both ``nnx.Linear`` layers.
    """

    def __init__(self, n_embed, rngs: nnx.Rngs):
        hidden_dim = 4 * n_embed
        self.fc1 = nnx.Linear(n_embed, hidden_dim, rngs=rngs)
        self.fc2 = nnx.Linear(hidden_dim, n_embed, rngs=rngs)

    def __call__(self, x):
        """Expand to the hidden width, apply ReLU, and project back."""
        hidden = self.fc1(x)
        activated = jax.nn.relu(hidden)
        return self.fc2(activated)

class Block(nnx.Module):
    """Transformer block: pre-norm attention and pre-norm MLP, each with a residual.

    Args:
        n_embed: Size of the last axis of the input and of the output.
        n_head: Number of attention heads; each head gets ``n_embed // n_head`` features.
        rngs: ``nnx.Rngs`` passed to every submodule.
    """

    def __init__(self, n_embed, n_head, rngs: nnx.Rngs):
        per_head = n_embed // n_head
        self.sa_heads = MultiHeadAttention(n_embed, n_head, per_head, rngs=rngs)
        self.ffwd = FeedForward(n_embed, rngs=rngs)
        self.ln1 = nnx.LayerNorm(n_embed, rngs=rngs)
        self.ln2 = nnx.LayerNorm(n_embed, rngs=rngs)

    def __call__(self, x):
        """Add the attention output, then the feed-forward output, to ``x``."""
        # Layer norm is applied to the branch input only; the residual path is untouched.
        attended = self.sa_heads(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

class GPT(nnx.Module):
    """Stack of ``Block`` modules between token/position embeddings and a linear head.

    Args:
        block_size: Number of rows in the position embedding table.
        vocab_size: Number of rows in the token embedding table and width of the output.
        n_embed: Embedding width used throughout the model.
        n_head: Number of attention heads in each block.
        n_blocks: Number of ``Block`` modules chained in sequence.
        rngs: ``nnx.Rngs`` passed to every submodule.
    """

    def __init__(self, block_size, vocab_size, n_embed, n_head, n_blocks, rngs: nnx.Rngs):
        self.block_size = block_size
        self.token_embedding_table = nnx.Embed(
            num_embeddings=vocab_size, features=n_embed, rngs=rngs
        )
        self.position_embedding_table = nnx.Embed(
            num_embeddings=block_size, features=n_embed, rngs=rngs
        )
        layers = []
        for _ in range(n_blocks):
            layers.append(Block(n_embed, n_head, rngs))
        self.blocks = nnx.Sequential(*layers)
        self.lm_head = nnx.Linear(n_embed, vocab_size, rngs=rngs)

    def __call__(self, x):
        """Map an index array to per-position logits.

        Args:
            x: Rank-2 array (shape unpacked as ``(B, T)``) used to index the
                token embedding table.
                NOTE(review): the position table has ``block_size`` rows, so
                ``T`` presumably must not exceed ``block_size`` - confirm.

        Returns:
            The output of ``lm_head``; its last axis has size ``vocab_size``.
        """
        _, seq_len = x.shape
        positions = jnp.arange(seq_len)
        token_emb = self.token_embedding_table(x)
        position_emb = self.position_embedding_table(positions)
        hidden = self.blocks(token_emb + position_emb)
        return self.lm_head(hidden)

# Instantiate the model and show we can run it.
# Positional arguments follow GPT.__init__:
#   block_size=8, vocab_size=64, n_embed=32, n_head=8, n_blocks=4.
model = GPT(8, 64, 32, 8, 4, rngs=nnx.Rngs(0))
# Minimal input: a (1, 1) int32 array holding the single index 0.
x = jnp.zeros((1, 1), dtype=jnp.int32)
# The last axis of the output matches vocab_size (64).
assert model(x).shape == (1, 1, 64)

In [27]:
# Split the model and keep only its state (the first return value is discarded),
# then write that state under `<ckpt_dir>/state`.
_, state = nnx.split(model)
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)

In [28]:
# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.
# The reference is built with the same constructor arguments as the saved model.
abstract_model = nnx.eval_shape(lambda: GPT(8, 64, 32, 8, 4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
print('The abstract NNX state (all leaves are abstract arrays):')
nnx.display(abstract_state)

# Load from the path written by the save cell, passing the abstract state as the
# restore target, then check every restored leaf against the in-memory `state`.
state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
jax.tree.map(np.testing.assert_array_equal, state, state_restored)
print('NNX State restored: ')
nnx.display(state_restored)

# The model is now good to use!
# Recombine the graph definition with the restored state and repeat the shape check.
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (1, 1, 64)

The abstract NNX state (all leaves are abstract arrays):




NNX State restored: 
