# Guides
https://flax.readthedocs.io/en/latest/guides/index.html

## JAX for the Impatient
### Managing randomness

In [6]:
import jax
from jax import numpy as jnp

In [2]:
key = jax.random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

In [3]:
for _ in range(3):
    print(f'key: {key} -> {jax.random.normal(key, (1,))}')

key: [0 0] -> [-0.20584226]
key: [0 0] -> [-0.20584226]
key: [0 0] -> [-0.20584226]


In [4]:
key, subkey = jax.random.split(key)
print(key)
print(subkey)

[4146024105  967050713]
[2718843009 1272950319]


In [5]:
key, *subkeys = jax.random.split(key, 4)
key, subkeys

(DeviceArray([3306097435, 3899823266], dtype=uint32),
 [DeviceArray([147607341, 367236428], dtype=uint32),
  DeviceArray([2280136339, 1907318301], dtype=uint32),
  DeviceArray([ 781391491, 1939998335], dtype=uint32)])

### Refining a bit with pytrees

In [7]:
t = [1, {'k1': 2, 'k2': (3, 4)}, 5]
jax.tree_util.tree_map(lambda x: x*x, t)

[1, {'k1': 4, 'k2': (9, 16)}, 25]

In [10]:
jax.tree_util.tree_map(lambda x: x**3, t)

[1, {'k1': 8, 'k2': (27, 64)}, 125]

In [12]:
t2 = jax.tree_util.tree_map(lambda x: x*x, t)
print(t2)

print(jax.tree_util.tree_map(lambda x, y: x+y, t, t2))

[1, {'k1': 4, 'k2': (9, 16)}, 25]
[2, {'k1': 6, 'k2': (12, 20)}, 30]


## Flax Basics

In [14]:
from typing import Any, Callable, Sequence
from flax.core import freeze, unfreeze
from flax import linen as nn

### Linear regression with Flax

In [15]:
model = nn.Dense(features=5)

In [16]:
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (10, ))
params = model.init(key2, x)
jax.tree_util.tree_map(lambda x: x.shape, params)

FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [17]:
try:
    params['new_key'] = jnp.ones((2, 2))
except ValueError as e:
    print('Error', e)

Error FrozenDict is immutable.


In [18]:
model.apply(params, x)

DeviceArray([-1.3721197 ,  0.61131513,  0.6442838 ,  2.2192965 ,
             -1.1271117 ], dtype=float32)

In [21]:
n_samples = 20
x_dim = 10
y_dim = 5

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (x_dim, y_dim))
b = jax.random.normal(k2, (y_dim,))

true_params = freeze({'params': {'bias': b, 'kernel': W}})

key_sample, key_noise = jax.random.split(k1)
x_samples = jax.random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * jax.random.normal(key_noise, (n_samples, y_dim))
print('x_shape:', x_samples.shape, '; y shape:', y_samples.shape)

x_shape: (20, 10) ; y shape: (20, 5)


In [22]:
@jax.jit
def mse(params, x_batches, y_batches):
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y-pred, y-pred) / 2.0
    return jnp.mean(jax.vmap(squared_error)(x_batches, y_batches), axis=0)

In [23]:
learning_rate = 0.3
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)
        

Loss for "true" W,b:  0.023639798
Loss step 0:  35.343876
Loss step 10:  0.5143469
Loss step 20:  0.11384161
Loss step 30:  0.03932675
Loss step 40:  0.019916205
Loss step 50:  0.014209128
Loss step 60:  0.012425651
Loss step 70:  0.0118503915
Loss step 80:  0.011661774
Loss step 90:  0.011599411
Loss step 100:  0.011578695


### Optimizing with Optax

In [24]:
import optax

In [25]:
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [27]:
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    if i % 10 == 0:
        print(f'Loss step {i}: {loss_val}')

Loss step 0: 0.011577622964978218
Loss step 10: 0.011571443639695644
Loss step 20: 0.011569392867386341
Loss step 30: 0.011568702757358551
Loss step 40: 0.011568485759198666
Loss step 50: 0.011568406596779823
Loss step 60: 0.011568374000489712
Loss step 70: 0.011568362824618816
Loss step 80: 0.011568366549909115
Loss step 90: 0.011568361893296242
Loss step 100: 0.01156836748123169


### Serializing the result

In [28]:
from flax import serialization

In [35]:
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)

print('Dict output')
print(dict_output)

print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': DeviceArray([-1.4540141 , -2.0262318 ,  2.080659  ,  1.22018   ,
             -0.99645686], dtype=float32), 'kernel': DeviceArray([[ 1.0106674 ,  0.19014898,  0.0453391 , -0.9272226 ,
               0.3472035 ],
             [ 1.7320229 ,  0.9901195 ,  1.1662225 ,  1.1027888 ,
              -0.10575178],
             [-1.2009125 ,  0.2883722 ,  1.4176373 ,  0.12073128,
              -1.3132595 ],
             [-1.1944962 , -0.18993425,  0.03379067,  1.3165945 ,
               0.07995945],
             [ 0.14102836,  1.3737915 , -1.3162129 ,  0.53401744,
              -2.239645  ],
             [ 0.56430227,  0.81360054,  0.31888163,  0.5359191 ,
               0.9035165 ],
             [-0.37948614,  1.7408308 ,  1.0788013 , -0.5041968 ,
               0.92868567],
             [ 0.970138  , -1.3158677 ,  0.33630812,  0.80941117,
              -1.2024575 ],
             [ 1.019825  , -0.61982715,  1.0822719 , -1.8385578 ,
              -0.45790663],
    

In [32]:
for k, v in dict_output['params'].items():
    print(k, v)

bias [-1.4540141  -2.0262318   2.080659    1.22018    -0.99645686]
kernel [[ 1.0106674   0.19014898  0.0453391  -0.9272226   0.3472035 ]
 [ 1.7320229   0.9901195   1.1662225   1.1027888  -0.10575178]
 [-1.2009125   0.2883722   1.4176373   0.12073128 -1.3132595 ]
 [-1.1944962  -0.18993425  0.03379067  1.3165945   0.07995945]
 [ 0.14102836  1.3737915  -1.3162129   0.53401744 -2.239645  ]
 [ 0.56430227  0.81360054  0.31888163  0.5359191   0.9035165 ]
 [-0.37948614  1.7408308   1.0788013  -0.5041968   0.92868567]
 [ 0.970138   -1.3158677   0.33630812  0.80941117 -1.2024575 ]
 [ 1.019825   -0.61982715  1.0822719  -1.8385578  -0.45790663]
 [-0.6438441   0.45648792 -1.1331053  -0.6855687   0.17010677]]


In [33]:
type(dict_output)

dict

In [36]:
serialization.from_bytes(params, bytes_output)

FrozenDict({
    params: {
        bias: array([-1.4540141 , -2.0262318 ,  2.080659  ,  1.22018   , -0.99645686],
              dtype=float32),
        kernel: array([[ 1.0106674 ,  0.19014898,  0.0453391 , -0.9272226 ,  0.3472035 ],
               [ 1.7320229 ,  0.9901195 ,  1.1662225 ,  1.1027888 , -0.10575178],
               [-1.2009125 ,  0.2883722 ,  1.4176373 ,  0.12073128, -1.3132595 ],
               [-1.1944962 , -0.18993425,  0.03379067,  1.3165945 ,  0.07995945],
               [ 0.14102836,  1.3737915 , -1.3162129 ,  0.53401744, -2.239645  ],
               [ 0.56430227,  0.81360054,  0.31888163,  0.5359191 ,  0.9035165 ],
               [-0.37948614,  1.7408308 ,  1.0788013 , -0.5041968 ,  0.92868567],
               [ 0.970138  , -1.3158677 ,  0.33630812,  0.80941117, -1.2024575 ],
               [ 1.019825  , -0.61982715,  1.0822719 , -1.8385578 , -0.45790663],
               [-0.6438441 ,  0.45648792, -1.1331053 , -0.6855687 ,  0.17010677]],
              dtype=float32

### Define your own models

In [39]:
class ExplicitMLT(nn.Module):
    features: Sequence[int]
    
    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]
        
    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x
    
key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.uniform(key1, (4, 4))

model = ExplicitMLT(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shape:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shape:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723787 -0.00810345 -0.0255093   0.02151708 -0.01261237]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [40]:
try:
    y = model(x)
except AttributeError as e:
    print(e)

"ExplicitMLT" object has no attribute "layers"


In [43]:
class SimpleMLP(nn.Module):
    features: Sequence[int]
    
    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layer_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
                
        return x
    
key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.uniform(key1, (4, 4))

model = SimpleMLP(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shape:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shape:
 {'params': {'layer_0': {'bias': (3,), 'kernel': (4, 3)}, 'layer_1': {'bias': (4,), 'kernel': (3, 4)}, 'layer_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[-0.48866612  0.5616314  -0.7335918  -0.7830883   0.00913267]
 [-0.25978276  0.31105274 -0.3998276  -0.4330669  -0.00686361]
 [-0.32860482  0.40832612 -0.56914747 -0.5940301   0.01978956]
 [-0.0976381   0.40497297 -1.0262899  -0.8247814   0.1343164 ]]
