**S04P01_flax_basics.ipynb**

Arz

2024 APR 25 (THU)

reference:

- https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html
- https://github.com/Tershire/JAX_Study/blob/master/S02/S02P01_tutorial_jax_as_accelerated_numpy.ipynb

# setting up our environment

In [1]:
import jax
import jax.numpy as jnp

In [2]:
import flax
from flax import linen as nn

In [3]:
from typing import Any, Callable, Sequence

# linear regression with Flax

linear regression can also be written as a single dense neural network layer.

## ex) dense layer

models (including layers) are subclasses of flax.linen.Module class.

https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.Dense.html

In [4]:
# a single dense layer; only the output dimension is given here —
# the input dimension is inferred from the input passed to init()
model = nn.Dense(features=3)  # output dimension is 3

### model parameters & initialization

⚠️ parameters are not stored with the models themselves. 

you need to initialize parameters by calling the **init** function, using a PRNGKey and dummy input data.

- dummy input data informs the model of the input number of features.

In [5]:
# split the seed-0 key in two: key1 draws the dummy input, key2 initializes the parameters
key1, key2 = jax.random.split(jax.random.key(0))

x = jax.random.normal(key1, (7,))  # dummy input data (7 input features)
params = model.init(key2, x)  # initialize model parameters

2024-05-26 15:45:12.189172: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
# check dimensions: replace every array leaf of the parameter pytree by its shape
jax.tree_util.tree_map(jnp.shape, params)

{'params': {'bias': (3,), 'kernel': (7, 3)}}

Flax is a row-based system, so a vector is represented as a row.

'kernel' (W) has shape (#input, #output)

### forward propagation

In [7]:
# forward pass: the parameters are passed in explicitly, together with the input
model.apply(params, x)

Array([ 1.3483415 , -0.4280271 , -0.10713735], dtype=float32)

### gradient descent

In [8]:
# data setup & generation
num_samples = 30
x_dim = 7  # input
y_dim = 3  # output

# ground-truth parameters: split the seed key in three and use the last two subkeys
key = jax.random.key(0)
_, *subkeys = jax.random.split(key, 3)
key_W, key_b = subkeys

W = jax.random.normal(key_W, (x_dim, y_dim))
b = jax.random.normal(key_b, (y_dim,))

original_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})

# samples and additive noise
# NOTE: key_W was already consumed to draw W above and is split again here (key reuse);
#       kept as-is so that the recorded outputs of this notebook stay reproducible.
_, *subkeys = jax.random.split(key_W, 3)
key_sample, key_noise = subkeys

x_samples = jax.random.normal(key_sample, (num_samples, x_dim))
noise = 0.1 * jax.random.normal(key_noise, (num_samples, y_dim))
y_samples = jnp.dot(x_samples, W) + b + noise  # y = xW + b + noise, one row per sample

print("x shape:", x_samples.shape, "y shape:", y_samples.shape)

x shape: (30, 7) y shape: (30, 3)


In [9]:
# loss function
def mean_squared_error(params, x_batch, y_batch):
    """
    mean over the batch of the per-sample squared error of the global `model`.

    params: parameter pytree passed to model.apply()
    x_batch, y_batch: inputs and targets, paired along the leading axis
    """
    def per_sample_error(x, y):
        # squared norm of the residual for a single pair (x, y)
        residual = y - model.apply(params, x)
        return jnp.inner(residual, residual)

    # vectorize over the leading axis, then average
    per_sample_errors = jax.vmap(per_sample_error)(x_batch, y_batch)
    return jnp.mean(per_sample_errors, axis=0)

In [10]:
# returns (loss, grads) in one call; it is called below as loss_grad_function(params, x, y)
loss_grad_function = jax.value_and_grad(mean_squared_error)

unlike the JAX example, this update_params() takes in grads.

in the JAX example, jax.grad() was called inside the update function. why the difference (?)

In [11]:
# param.s update
@jax.jit
def update_params(params, alpha, grads):
    """
    one gradient-descent step, applied leaf by leaf: p <- p - alpha * g

    params: parameter pytree
    alpha: learning rate
    grads: gradient pytree with the same structure as params
    """
    def descend(param, grad):
        return param - alpha * grad

    return jax.tree_util.tree_map(descend, params, grads)

In [12]:
# train
alpha = 0.3  # learning rate

max_step = 100
for i in range(max_step):
    # loss and gradients are computed on the full sample set at every step
    loss, grads = loss_grad_function(params, x_samples, y_samples)
    params = update_params(params, alpha, grads)

    # log every 10th step and the final step; the printed loss is the one
    # computed before this step's update was applied
    if i % 10 == 0 or i == max_step - 1:
        print(f"loss at step {i:3.0F}: {loss:.12F}")

loss at step   0: 33.618389129639
loss at step  10: 0.044064313173
loss at step  20: 0.027380581945
loss at step  30: 0.027127481997
loss at step  40: 0.027123507112
loss at step  50: 0.027123449370
loss at step  60: 0.027123443782
loss at step  70: 0.027123438194
loss at step  80: 0.027123436332
loss at step  90: 0.027123436332
loss at step  99: 0.027123434469


### optimizing with Optax

- 1. choose optimization method (ex. adam).
- 2. initialize optimizer state given model parameters.
- 3. compute loss gradients using *jax.value_and_grad()*
- 4. at every iteration,
 
        - call *update()* to compute the parameter updates and the new optimizer state.
        - call *apply_updates()* to apply those updates to the model parameters.

In [20]:
import optax

In [21]:
# initialize model parameters
# (same key and dummy input as the first init, so this run starts from the same
#  parameters as the plain gradient-descent run — compare the loss at step 0)
params = model.init(key2, x)

In [22]:
# create optimizer
optimizer = optax.adam(learning_rate=alpha)  # reuses alpha from the gradient-descent cell
optimizer_state = optimizer.init(params)  # optimizer state is built from the parameter pytree

In [23]:
# train
max_step = 100
for i in range(max_step):
    loss, grads = loss_grad_function(params, x_samples, y_samples)
    # update() returns the parameter updates together with the new optimizer state
    params_update, optimizer_state = optimizer.update(grads, optimizer_state)
    # apply_updates() applies those updates to obtain the new parameters
    params = optax.apply_updates(params, params_update)
    
    # log every 10th step and the final step
    if i % 10 == 0 or i == max_step - 1:
        print(f"loss at step {i:3.0F}: {loss:.12F}")

loss at step   0: 33.618389129639
loss at step  10: 4.037739753723
loss at step  20: 1.026890873909
loss at step  30: 0.303563058376
loss at step  40: 0.205362066627
loss at step  50: 0.078522302210
loss at step  60: 0.046105630696
loss at step  70: 0.034731782973
loss at step  80: 0.029582021758
loss at step  90: 0.027837540954
loss at step  99: 0.027524823323


### serializing the result

{save, load} model parameters.

- **save**

In [27]:
# save: convert the trained parameters into a nested state dict of arrays
saved_params_as_dict = flax.serialization.to_state_dict(params)
print(saved_params_as_dict)

{'params': {'bias': Array([ 0.09807936, -1.1808372 , -0.2617658 ], dtype=float32), 'kernel': Array([[-1.8411863 , -0.50547993,  0.55073863],
       [-1.2407506 , -0.6748584 , -0.7191346 ],
       [ 1.1142238 , -2.4242043 ,  0.07755796],
       [ 0.46512845,  0.80414766, -1.4551915 ],
       [-1.2403113 , -2.006448  ,  0.5853151 ],
       [-0.9453332 ,  0.9392538 ,  0.40514505],
       [ 0.2739672 , -0.30546483,  0.04342157]], dtype=float32)}}


In [28]:
# save: serialize the trained parameters into a single byte string
saved_params_as_bytes = flax.serialization.to_bytes(params)
print(saved_params_as_bytes)

b'\x81\xa6params\x82\xa4bias\xc7\x19\x01\x93\x91\x03\xa7float32\xc4\x0c\xd5\xdd\xc8=\xac%\x97\xbf+\x06\x86\xbe\xa6kernel\xc7b\x01\x93\x92\x07\x03\xa7float32\xc4T\xfe\xab\xeb\xbf"g\x01\xbf5\xfd\x0c?\xea\xd0\x9e\xbf\x85\xc3,\xbf5\x198\xbf\xe3\x9e\x8e?*&\x1b\xc0\xb5\xd6\x9e=Q%\xee>\x9f\xdcM?\xb7C\xba\xbf\x85\xc2\x9e\xbf\xa5i\x00\xc06\xd7\x15?[\x01r\xbf\xf0rp?,o\xcf>nE\x8c>\xe3e\x9c\xbe\xd2\xda1='


- **load**

you need to provide a parameter template as the first argument, so that the load function can recognize the parameter structure.

here, **params** will not be modified but will provide the template.

In [30]:
# load: restore from the state dict; `params` is passed only as the structural template
loaded_params_from_dict = flax.serialization.from_state_dict(params, saved_params_as_dict)
print(loaded_params_from_dict)

{'params': {'bias': Array([ 0.09807936, -1.1808372 , -0.2617658 ], dtype=float32), 'kernel': Array([[-1.8411863 , -0.50547993,  0.55073863],
       [-1.2407506 , -0.6748584 , -0.7191346 ],
       [ 1.1142238 , -2.4242043 ,  0.07755796],
       [ 0.46512845,  0.80414766, -1.4551915 ],
       [-1.2403113 , -2.006448  ,  0.5853151 ],
       [-0.9453332 ,  0.9392538 ,  0.40514505],
       [ 0.2739672 , -0.30546483,  0.04342157]], dtype=float32)}}


In [31]:
# load: restore from the byte string, again using `params` as the template
# note: in the recorded output the restored leaves print as NumPy `array`, not JAX `Array`
loaded_params_from_bytes = flax.serialization.from_bytes(params, saved_params_as_bytes)
print(loaded_params_from_bytes)

{'params': {'bias': array([ 0.09807936, -1.1808372 , -0.2617658 ], dtype=float32), 'kernel': array([[-1.8411863 , -0.50547993,  0.55073863],
       [-1.2407506 , -0.6748584 , -0.7191346 ],
       [ 1.1142238 , -2.4242043 ,  0.07755796],
       [ 0.46512845,  0.80414766, -1.4551915 ],
       [-1.2403113 , -2.006448  ,  0.5853151 ],
       [-0.9453332 ,  0.9392538 ,  0.40514505],
       [ 0.2739672 , -0.30546483,  0.04342157]], dtype=float32)}}


# defining your own models

## ex) multi-layer perceptron (MLP)

In [40]:
class Explicit_MLP(nn.Module):
    """Multi-layer perceptron whose Dense sub-layers are declared explicitly in setup().

    ReLU is applied after every layer except the last one, so the output of
    the final layer is returned without an activation.
    """
    features: Sequence[int]  # sequence of output feature dimension per layer

    def setup(self):
        # one Dense layer per entry of `features`
        self.layers = [nn.Dense(feature) for feature in self.features]

    def __call__(self, inputs):
        x = inputs
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # compare by value: `is not` on ints tests object identity, which only
            # happens to work for small integers that the interpreter caches
            if i != last_index:
                x = nn.relu(x)
        return x

In [41]:
# input
# split the seed-0 key in three and use the last two: key1 for the input batch, key2 for init
_, *subkeys = jax.random.split(jax.random.key(0), 3)
key1, key2 = subkeys[0], subkeys[1]

num_samples = 10
x_dim = 7
x = jax.random.uniform(key1, (num_samples, x_dim))  # batch of 10 inputs with 7 features each

# model
model = Explicit_MLP(features=[9, 5, 3])  # three Dense layers with output sizes 9, 5, 3
params = model.init(key2, x)

# output
y = model.apply(params, x)

# unfreeze() is presumably used so the shapes print as a plain nested dict —
# TODO confirm whether init() returns a FrozenDict in the installed Flax version
print("params shape:\n", jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print("y:\n", y)

params shape:
 {'params': {'layers_0': {'bias': (9,), 'kernel': (7, 9)}, 'layers_1': {'bias': (5,), 'kernel': (9, 5)}, 'layers_2': {'bias': (3,), 'kernel': (5, 3)}}}
y:
 [[-0.10174     0.468905    0.2727127 ]
 [-0.40046442  0.5607383   0.3304994 ]
 [-0.23100257  0.57434875  0.2876985 ]
 [-0.3043432   0.463985    0.27915537]
 [-0.21458422  0.5588921   0.3658545 ]
 [-0.5424561   0.9484465   0.5600072 ]
 [-0.19780068  0.3608388   0.17913571]
 [-0.2271311   0.55443084  0.3432787 ]
 [-0.13232163  0.49851972  0.32358307]
 [-0.51897645  0.73706603  0.43027398]]
