<a href="https://colab.research.google.com/github/afairley/ColaboratoryNotebooks/blob/main/FlaxBasics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade  -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.2/79.2 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for flax (pyproject.toml) ... [?25l[?25hdone
[0m

In [None]:
model = nn.Dense(features=5)
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))
params = model.init(key2,x)
jax.tree_util.tree_map(lambda x: x.shape,params)
model.apply(params,x)

n_samples = 20
x_dim = 10
y_dim = 5

nextKey = random.key(0)
k1, k2 = random.split(nextKey)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2,(y_dim,))

true_params = flax.core.freeze({'params':{'bias': b, 'kernel': W}})
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples,W) + b + 0.1 *\
 random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
print('x:', x_samples, '; y:', y_samples)

In [16]:
#@jax.jit
def mean_squared_error(params, model, x_batched, y_batched):
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

In [20]:
learning_rate = 0.3
print('Loss for "true" W, b : ', mean_squared_error(true_params, model, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mean_squared_error)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params
print("Reinitializing parameters")
params = model.init(key2,x)
print("\nParams\n", params, "\n")
for i in range(101):
  loss_val, grads = loss_grad_fn(params, model, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}:', loss_val)
print("\nParams\n", params, "\n")

Loss for "true" W, b :  0.023639789

Params
 {'params': {'kernel': Array([[ 2.35571519e-01, -1.71652585e-01, -4.45728786e-02,
        -4.68043566e-01,  4.54595268e-01],
       [-6.87736452e-01,  3.67835373e-01, -1.79262087e-01,
         1.29276231e-01, -2.42580160e-01],
       [ 2.02303097e-01, -2.49465615e-01,  2.74955630e-01,
         4.73488361e-01, -1.98002517e-01],
       [ 2.74478316e-01, -1.21369645e-01, -2.25361675e-01,
        -4.78193641e-01, -9.63979885e-02],
       [-6.19886033e-02, -1.72743499e-01,  2.96945305e-04,
        -7.17593372e-01,  2.00894207e-01],
       [-5.60321152e-01,  3.27208370e-01,  1.06281497e-01,
         1.28758654e-01,  1.16973236e-01],
       [ 1.82218999e-01,  1.11444063e-01, -1.62924141e-01,
         3.24953087e-02, -1.67053342e-01],
       [ 4.31294113e-01,  2.08004564e-01,  1.47714227e-01,
        -8.51502866e-02, -1.26487061e-01],
       [ 3.29497308e-01,  1.08470365e-01, -4.01340067e-01,
         1.66956007e-01,  5.74723601e-01],
       [-3.8474

In [None]:
import optax
tx = optax.adam(learning_rate=learning_rate)
