# Training a simple MLP with FLAX

In [2]:
# install flax if haven't done so
!pip install -q flax

In [2]:
# a toy example using flax
import numpy as np

import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import optax # optimization library for JAX from deepmind

In [3]:
# create a dense layer
model = nn.Dense(features=5)

# Psedo random number generation
key = jax.random.PRNGKey(0)
print(key)
key1, key2 = jax.random.split(key)
print(key, key1, key2, '\n')

# initialize input
x = jax.random.normal(key1, (10,))
params = model.init(key2, x)
print("x: {}\n\nparams: {}\n".format(x, params.keys()))

jax.tree_util.tree_map(lambda x: x.shape, params)

[0 0]
[0 0] [4146024105  967050713] [2718843009 1272950319] 

x: [-2.6105583   0.03385283  1.0863333  -1.4802988   0.48895672  1.062516
  0.54174834  0.0170228   0.2722685   0.30522448]

params: frozen_dict_keys(['params'])



FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [25]:
y = model.apply(params, x)
print(y)

[-1.3721193   0.61131495  0.6442836   2.2192965  -1.1271116 ]


## Define a MLP module

### (1) Define a MLP with FLAX's pre-defined layers

In [10]:
import flax.linen as nn
from flax.core import freeze, unfreeze
from typing import Sequence

class MLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]
        # self.layer1 = nn.Dense(feat1)
        self.num_layers = len(self.layers)
    
    def __call__(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i != self.num_layers-1:
                x = nn.relu(x)
        return x

In [11]:
import jax.random as random
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = MLP(features=[2,3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized param shapes:\n', 
      jax.tree_util.tree_map(jnp.shape, unfreeze(params)))

print('output:\n', y)

initialized param shapes:
 {'params': {'layers_0': {'bias': (2,), 'kernel': (4, 2)}, 'layers_1': {'bias': (3,), 'kernel': (2, 3)}, 'layers_2': {'bias': (4,), 'kernel': (3, 4)}, 'layers_3': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


### (2) Customize layers

In [43]:
from typing import Callable
from jax import lax # lib for primitive ops

class CustomDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact # about this decorator: https://github.com/google-research/vision_transformer/issues/118
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init,
                            (inputs.shape[-1], self.features)
        )
        y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),)
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias
        return y

In [44]:
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = CustomDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.61506   , -0.22728713,  0.6054702 ],
                     [-0.29617992,  1.1232013 , -0.879759  ],
                     [-0.35162622,  0.3806491 ,  0.6893246 ],
                     [-0.1151355 ,  0.04567898, -1.091212  ]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})
output:
 [[-0.02996203  1.102088   -0.6660265 ]
 [-0.31092793  0.63239413 -0.53678817]
 [ 0.01424009  0.9424717  -0.63561463]
 [ 0.3681896   0.3586519  -0.00459218]]


## Gradient Descent

### (1) Gradient descent without FLAX (pure JAX)

In [21]:
# Gradient descent for linear regression
n_samples = 20
x_dim, y_dim = 10, 5

# create parameters to optimize
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input

model = nn.Dense(features=5)
params = model.init(key2, x) # Initialization call
print(jax.tree_util.tree_map(lambda x: x.shape, params)) # Checking output shapes

# create optimized parameters 
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
true_params = freeze({
    'params':{
        'bias': b,
        'kernel': W
    }
})

# generate GT data for training
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
rand_noise = 0.1*random.normal(key_noise, (n_samples, y_dim))
y_samples = jnp.dot(x_samples, W) + b + rand_noise

print('x_shape: {}, y_shape: {}\n'.format(x_samples.shape, y_samples.shape))


FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})
x_shape: (20, 10), y_shape: (20, 5)



In [23]:
# function to get loss
@jax.jit
def mse(params, x_batch, y_batch):
    def sqrt_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y-pred, y-pred)/2.0
    # clever use of vmap, jax's cool feature
    return jnp.mean(jax.vmap(sqrt_error)(x_batch, y_batch), axis=0)

# function to get gradient
loss_grad_fn = jax.value_and_grad(mse)

# function to update params
@jax.jit
def update_params(params, lr, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - lr*g, params, grads
    )
    return params


loss with true W,b: 0.023639793


In [37]:
# test getting loss
lr = 0.3
loss = mse(true_params, x_samples, y_samples)
print('loss with true W,b:', loss)

# test getting gradients
loss, grads = loss_grad_fn(params, x_samples, y_samples)
#print('loss:{}\n gradient: {}\n'.format(loss, grads))

# test updating params
#print('original params: {}'.format(params))
params = update_params(params, lr, grads)
#print('updated params: {}'.format(params))

# perform back-propagation for multiple iterations
for i in range(0, 31):
    loss, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, lr, grads)
    if i % 10 == 0:
        print('loss step {}: {}'.format(i, loss))






loss with true W,b: 0.023639793
loss step 0: 0.011568372137844563
loss step 10: 0.011568372137844563
loss step 20: 0.011568372137844563
loss step 30: 0.011568372137844563


### (2) Gradient Descent with *Optax* (lib from deepmind)

In [38]:
import optax
optimizer = optax.sgd(learning_rate=lr)
opt_state = optimizer.init(params)

loss_grad_fn = jax.value_and_grad(mse)

for i in range(31):
    loss, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    if i % 10 == 0:
        print('loss at step {}: {}'.format(i, loss))

loss at step 0: 0.011568371206521988
loss at step 10: 0.011568371206521988
loss at step 20: 0.011568371206521988
loss at step 30: 0.011568371206521988


### Serialize params for export and import

In [36]:
# serialize params for export
from flax import serialization
bytes_output = serialization.to_bytes(params)   #choice1
dict_output = serialization.to_state_dict(params) #choice2

# import params
#serialization.from_bytes(params, bytes_output)  # choice1
serialization.from_state_dict(params, dict_output)   # choice2


FrozenDict({
    params: {
        bias: DeviceArray([-1.4540124 , -2.0262275 ,  2.0806599 ,  1.2201837 ,
                     -0.99645793], dtype=float32),
        kernel: DeviceArray([[ 1.0106655 ,  0.19014445,  0.04533757, -0.9272265 ,
                       0.3472048 ],
                     [ 1.732027  ,  0.99013054,  1.1662259 ,  1.102798  ,
                      -0.10575476],
                     [-1.2009128 ,  0.28837118,  1.4176372 ,  0.12073042,
                      -1.3132594 ],
                     [-1.194495  , -0.18993127,  0.03379178,  1.3165966 ,
                       0.07995866],
                     [ 0.14103451,  1.3738064 , -1.3162082 ,  0.5340303 ,
                      -2.2396488 ],
                     [ 0.5643062 ,  0.8136104 ,  0.31888482,  0.53592736,
                       0.903514  ],
                     [-0.3794808 ,  1.7408438 ,  1.0788052 , -0.5041857 ,
                       0.9286824 ],
                     [ 0.97013855, -1.3158665 ,  0.33630857,  0.8