In [11]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
import flax
import jax.numpy as jnp
import jax.random as random
import flax.linen as nn

import matplotlib.pyplot as plt

In [2]:
model = nn.Dense(features=5)
model

Dense(
    # attributes
    features = 5
    use_bias = True
    dtype = None
    param_dtype = float32
    precision = None
    kernel_init = init
    bias_init = zeros
    dot_general = None
    dot_general_cls = None
)

In [3]:
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))
params = model.init(key2, x)
jax.tree_util.tree_map(lambda x: x.shape, params)

2024-03-11 21:34:20.862995: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 11.6 which is older than the ptxas CUDA version (11.8.89). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


{'params': {'bias': (5,), 'kernel': (10, 5)}}

In [5]:
model.apply(params, x)

Array([-1.3721199 ,  0.611315  ,  0.64428365,  2.2192967 , -1.1271119 ],      dtype=float32)

In [9]:
x.devices()

{cuda(id=0)}

In [10]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [13]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [14]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023639798
Loss step 0:  35.343876
Loss step 10:  0.5150507
Loss step 20:  0.114045255
Loss step 30:  0.039395206
Loss step 40:  0.019940186
Loss step 50:  0.014217614
Loss step 60:  0.012428728
Loss step 70:  0.011851474
Loss step 80:  0.011662136
Loss step 90:  0.0115995305
Loss step 100:  0.011578723


In [15]:
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [16]:
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.0115776425
Loss step 10:  0.26137036
Loss step 20:  0.076836
Loss step 30:  0.03648498
Loss step 40:  0.02203123
Loss step 50:  0.016186504
Loss step 60:  0.012997282
Loss step 70:  0.012026423
Loss step 80:  0.011765248
Loss step 90:  0.011645812
Loss step 100:  0.011585565


In [18]:
opt_state

(ScaleByAdamState(count=Array(101, dtype=int32), mu={'params': {'bias': Array([8.3893799e-05, 8.6121960e-05, 8.3097213e-05, 8.4847532e-05,
        2.8190637e-04], dtype=float32), 'kernel': Array([[ 3.9197803e-05,  3.8061680e-05,  3.5391680e-05,  3.9148064e-05,
         -6.1675746e-06],
        [-4.9616185e-05, -4.7960111e-05, -4.9128044e-05, -4.7233192e-05,
          1.5249492e-05],
        [ 1.2955531e-04,  1.3588111e-04,  1.3098563e-04,  1.3318549e-04,
          1.0369407e-04],
        [ 4.6785521e-05,  4.6059609e-05,  4.7439731e-05,  4.4369255e-05,
         -1.6793149e-04],
        [ 3.1651914e-05,  3.2143878e-05,  3.2318749e-05,  3.1883526e-05,
          5.3564363e-05],
        [-1.2694288e-04, -1.2165749e-04, -1.2994988e-04, -1.2423823e-04,
         -3.2156322e-04],
        [-1.6304413e-04, -1.6415064e-04, -1.6482500e-04, -1.6351035e-04,
          8.4395942e-06],
        [ 1.6221512e-04,  1.6363618e-04,  1.5648371e-04,  1.6195382e-04,
         -9.4930510e-06],
        [-1.5283289e

In [19]:
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': Array([-1.4555806, -2.0278268,  2.0790968,  1.2186172, -0.9980628],      dtype=float32), 'kernel': Array([[ 1.0098493 ,  0.18933125,  0.04458343, -0.92802316,  0.34783497],
       [ 1.7298489 ,  0.9879658 ,  1.1640443 ,  1.1006037 , -0.10651544],
       [-1.2029523 ,  0.2863497 ,  1.4156082 ,  0.11870176, -1.3141391 ],
       [-1.1941417 , -0.18958248,  0.03414451,  1.3169445 ,  0.0805987 ],
       [ 0.13851093,  1.3712997 , -1.318749  ,  0.53152126, -2.2405198 ],
       [ 0.5629417 ,  0.8122362 ,  0.31753275,  0.53455   ,  0.90499985],
       [-0.37926182,  1.7410471 ,  1.0790585 , -0.5039784 ,  0.9282756 ],
       [ 0.9706411 , -1.3153212 ,  0.3368311 ,  0.80993503, -1.2018685 ],
       [ 1.0194358 , -0.6202532 ,  1.0818852 , -1.8389667 , -0.45808962],
       [-0.6436615 ,  0.4566898 , -1.1329143 , -0.6853882 ,  0.16831206]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14wP\xba\xbf\xea\xc7\x01\xc