In [9]:
import jax
import jax.numpy as jnp
import jax.random as random
import flax
from transformers import FlaxAutoModel

In [5]:
model = FlaxAutoModel.from_pretrained("gpt2")

Downloading: 100%|██████████| 475M/475M [00:15<00:00, 31.7MB/s] 


In [64]:
model_params, tree_struct = jax.tree_flatten(jax.tree_map(lambda x: x.size, model.params))
model_params = sum(model_params)
print(f"Model params: {(model_params / 10**6):.2f}M")

emb_params, _ = jax.tree_flatten(jax.tree_map(
    lambda x: x.size,
    {"wpe": model.params["wpe"], "wte": model.params["wte"]}
))
emb_params = sum(emb_params)

print(f"Emb params: {(emb_params / 10**6):.2f}M")

print(f"Model non-emb params: {(model_params - emb_params) / 10**6:.2f}M")

# LR(N ) ≈ 0.003239 + −0.0001395 log(N ) (D.1)
lr = 0.003239 + (-0.0001395) * jnp.log(model_params)
print(f"Learning rate: {lr:e}")

Model params: 124.44M
Emb params: 39.38M
Model non-emb params: 85.06M
Learning rate: 6.388132e-04
