**S04P01_flax_basics.ipynb**

Arz

2024 APR 25 (THU)

reference:
https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html

# setting up our environment

In [1]:
import jax
import jax.numpy as jnp

In [2]:
import flax
from flax import linen as nn

In [3]:
from typing import Any, Callable, Sequence

# linear regression with Flax

linear regression can also be written as a single dense neural network layer with no activation function: y = xW + b.
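to make the correspondence concrete, here is a minimal sketch in plain JAX of the affine map that a dense layer computes. the helper name `dense_forward` and the weight values are hypothetical, chosen only for illustration.

```python
import jax.numpy as jnp

def dense_forward(x, kernel, bias):
    """Affine map y = x @ kernel + bias (hypothetical helper, not a Flax API)."""
    return jnp.dot(x, kernel) + bias

# illustrative weights: 2 input features -> 1 output
kernel = jnp.array([[2.0], [-1.0]])
bias = jnp.array([0.5])
x = jnp.array([3.0, 4.0])

y = dense_forward(x, kernel, bias)  # 2*3 + (-1)*4 + 0.5
print(y)
```

fitting a linear regression then amounts to choosing `kernel` and `bias` so that `y` matches the targets.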

## ex) dense layer

models (including layers) are subclasses of the flax.linen.Module class.

https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.Dense.html

In [4]:
model = nn.Dense(features=3)  # output dimension is 3

### model parameters & initialization

⚠️ parameters are not stored with the models themselves. 

you need to initialize the parameters by calling the model's **init** method with a PRNG key and dummy input data. the dummy input is only used to infer the parameter shapes.

In [5]:
key1, key2 = jax.random.split(jax.random.key(0))

x = jax.random.normal(key1, (7,))  # dummy input data
params = model.init(key2, x)  # initialize model parameters

In [7]:
# check dimensions
jax.tree_util.tree_map(lambda x: x.shape, params)

{'params': {'bias': (3,), 'kernel': (7, 3)}}

### forward propagation

In [8]:
model.apply(params, x)

Array([ 1.3483415 , -0.4280271 , -0.10713735], dtype=float32)

### gradient descent